diff --git a/botorch/optim/closures/model_closures.py b/botorch/optim/closures/model_closures.py index 8e4c39a0f2..ecbe970bdb 100644 --- a/botorch/optim/closures/model_closures.py +++ b/botorch/optim/closures/model_closures.py @@ -12,8 +12,8 @@ from typing import Any, Callable, Dict, Optional, Sequence, Tuple from botorch.optim.closures.core import ForwardBackwardClosure -from botorch.optim.utils import TNone from botorch.utils.dispatcher import Dispatcher, type_bypassing_encoder +from botorch.utils.types import NoneType from gpytorch.mlls import ( ExactMarginalLogLikelihood, MarginalLogLikelihood, @@ -151,9 +151,9 @@ def closure(**kwargs: Any) -> Tensor: return closure -@GetLossClosure.register(MarginalLogLikelihood, object, object, TNone) +@GetLossClosure.register(MarginalLogLikelihood, object, object, NoneType) def _get_loss_closure_fallback_internal( - mll: MarginalLogLikelihood, _: object, __: object, ___: TNone, **ignore: Any + mll: MarginalLogLikelihood, _: object, __: object, ___: NoneType, **ignore: Any ) -> Callable[[], Tensor]: r"""Fallback loss closure with internally managed data.""" @@ -165,9 +165,9 @@ def closure(**kwargs: Any) -> Tensor: return closure -@GetLossClosure.register(ExactMarginalLogLikelihood, object, object, TNone) +@GetLossClosure.register(ExactMarginalLogLikelihood, object, object, NoneType) def _get_loss_closure_exact_internal( - mll: ExactMarginalLogLikelihood, _: object, __: object, ___: TNone, **ignore: Any + mll: ExactMarginalLogLikelihood, _: object, __: object, ___: NoneType, **ignore: Any ) -> Callable[[], Tensor]: r"""ExactMarginalLogLikelihood loss closure with internally managed data.""" @@ -181,9 +181,9 @@ def closure(**kwargs: Any) -> Tensor: return closure -@GetLossClosure.register(SumMarginalLogLikelihood, object, object, TNone) +@GetLossClosure.register(SumMarginalLogLikelihood, object, object, NoneType) def _get_loss_closure_sum_internal( - mll: SumMarginalLogLikelihood, _: object, __: object, ___: TNone, **ignore: Any + mll: SumMarginalLogLikelihood, _: object, __: object, ___: NoneType, **ignore: Any ) -> Callable[[], Tensor]: r"""SumMarginalLogLikelihood loss closure with internally managed data.""" diff --git a/botorch/optim/fit.py b/botorch/optim/fit.py index c9925bc74e..098d678666 100644 --- a/botorch/optim/fit.py +++ b/botorch/optim/fit.py @@ -43,12 +43,12 @@ from botorch.optim.utils import ( _filter_kwargs, _get_extra_mll_args, - DEFAULT, get_name_filter, get_parameters_and_bounds, TorchAttr, ) from botorch.optim.utils.model_utils import get_parameters +from botorch.utils.types import DEFAULT from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood from gpytorch.settings import fast_computations from numpy import ndarray diff --git a/botorch/optim/utils/__init__.py b/botorch/optim/utils/__init__.py index 552363898a..8ae3cf02cb 100644 --- a/botorch/optim/utils/__init__.py +++ b/botorch/optim/utils/__init__.py @@ -13,8 +13,6 @@ _filter_kwargs, _handle_numerical_errors, _warning_handler_template, - DEFAULT, - TNone, ) from botorch.optim.utils.model_utils import ( _get_extra_mll_args, @@ -40,7 +38,6 @@ "_warning_handler_template", "as_ndarray", "columnwise_clamp", - "DEFAULT", "fix_features", "get_name_filter", "get_bounds_as_ndarray", @@ -53,5 +50,4 @@ "sample_all_priors", "set_tensors_from_ndarray_1d", "TorchAttr", - "TNone", ] diff --git a/botorch/optim/utils/common.py b/botorch/optim/utils/common.py index 773254a183..5cd687b104 100644 --- a/botorch/optim/utils/common.py +++ b/botorch/optim/utils/common.py @@ -16,15 +16,6 @@ import numpy as np from linear_operator.utils.errors import NanError, NotPSDError -TNone = type(None) - - -class _TDefault: - pass - - -DEFAULT = _TDefault() - def _filter_kwargs(function: Callable, **kwargs: Any) -> Any: r"""Filter out kwargs that are not applicable for a given function. diff --git a/botorch/optim/utils/numpy_utils.py b/botorch/optim/utils/numpy_utils.py index 894a5f4e70..a40b40cb4f 100644 --- a/botorch/optim/utils/numpy_utils.py +++ b/botorch/optim/utils/numpy_utils.py @@ -14,7 +14,7 @@ import numpy as np import torch -from botorch.optim.utils.common import TNone +from botorch.utils.types import NoneType from numpy import ndarray from torch import Tensor @@ -137,7 +137,9 @@ def set_tensors_from_ndarray_1d( def get_bounds_as_ndarray( parameters: Dict[str, Tensor], - bounds: Dict[str, Tuple[Union[float, Tensor, TNone], Union[float, Tensor, TNone]]], + bounds: Dict[ + str, Tuple[Union[float, Tensor, NoneType], Union[float, Tensor, NoneType]] + ], ) -> Optional[np.ndarray]: r"""Helper method for converting bounds into an ndarray. diff --git a/botorch/sampling/pathwise/__init__.py b/botorch/sampling/pathwise/__init__.py new file mode 100644 index 0000000000..b78b774b15 --- /dev/null +++ b/botorch/sampling/pathwise/__init__.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from botorch.sampling.pathwise.features import ( + gen_kernel_features, + KernelEvaluationMap, + KernelFeatureMap, +) +from botorch.sampling.pathwise.paths import ( + GeneralizedLinearPath, + PathDict, + PathList, + SamplePath, +) +from botorch.sampling.pathwise.posterior_samplers import ( + draw_matheron_paths, + MatheronPath, +) +from botorch.sampling.pathwise.prior_samplers import draw_kernel_feature_paths +from botorch.sampling.pathwise.update_strategies import gaussian_update + + +__all__ = [ + "draw_matheron_paths", + "draw_kernel_feature_paths", + "gen_kernel_features", + "gaussian_update", + "GeneralizedLinearPath", + "KernelEvaluationMap", + "KernelFeatureMap", + "MatheronPath", + "SamplePath", + "PathDict", + "PathList", +] diff --git a/botorch/sampling/pathwise/features/__init__.py b/botorch/sampling/pathwise/features/__init__.py new file mode 100644 index 0000000000..9f29581e65 --- /dev/null +++ b/botorch/sampling/pathwise/features/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from botorch.sampling.pathwise.features.generators import gen_kernel_features +from botorch.sampling.pathwise.features.maps import ( + FeatureMap, + KernelEvaluationMap, + KernelFeatureMap, +) + +__all__ = [ + "FeatureMap", + "gen_kernel_features", + "KernelEvaluationMap", + "KernelFeatureMap", +] diff --git a/botorch/sampling/pathwise/features/generators.py b/botorch/sampling/pathwise/features/generators.py new file mode 100644 index 0000000000..42fd30c8d2 --- /dev/null +++ b/botorch/sampling/pathwise/features/generators.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +.. [rahimi2007random] + A. Rahimi and B. Recht. Random features for large-scale kernel machines. + Advances in Neural Information Processing Systems 20 (2007). + +.. [sutherland2015error] + D. J. Sutherland and J. Schneider. On the error of random Fourier features. + arXiv preprint arXiv:1506.02785 (2015). +""" + +from __future__ import annotations + +from typing import Any, Callable + +import torch +from botorch.exceptions.errors import UnsupportedError +from botorch.sampling.pathwise.features.maps import KernelFeatureMap +from botorch.sampling.pathwise.utils import ( + ChainedTransform, + FeatureSelector, + InverseLengthscaleTransform, + OutputscaleTransform, + SineCosineTransform, +) +from botorch.utils.dispatcher import Dispatcher +from botorch.utils.sampling import draw_sobol_normal_samples +from gpytorch import kernels +from gpytorch.kernels.kernel import Kernel +from torch import Size, Tensor +from torch.distributions import Gamma + +TKernelFeatureMapGenerator = Callable[[Kernel, int, int], KernelFeatureMap] +GenKernelFeatures = Dispatcher("gen_kernel_features") + + +def gen_kernel_features( + kernel: kernels.Kernel, + num_inputs: int, + num_outputs: int, + **kwargs: Any, +) -> KernelFeatureMap: + r"""Generates a feature map :math:`\phi: \mathcal{X} \to \mathbb{R}^{n}` such that + :math:`k(x, x') ≈ \phi(x)^{T} \phi(x')`. For stationary kernels :math:`k`, defaults + to the method of random Fourier features. For more details, see [rahimi2007random]_ + and [sutherland2015error]_. + + Args: + kernel: The kernel :math:`k` to be represented via a finite-dim basis. + num_inputs: The number of input features. + num_outputs: The number of kernel features. + """ + return GenKernelFeatures( + kernel, + num_inputs=num_inputs, + num_outputs=num_outputs, + **kwargs, + ) + + +def _gen_fourier_features( + kernel: kernels.Kernel, + weight_generator: Callable[[Size], Tensor], + num_inputs: int, + num_outputs: int, +) -> KernelFeatureMap: + r"""Generate a feature map :math:`\phi: \mathcal{X} \to \mathbb{R}^{2l}` that + approximates a stationary kernel so that :math:`k(x, x') ≈ \phi(x)^\top \phi(x')`. + + Following [sutherland2015error]_, we represent complex exponentials by pairs of + basis functions :math:`\phi_{i}(x) = \sin(x^\top w_{i})` and + :math:`\phi_{i + l} = \cos(x^\top w_{i}). + + Args: + kernel: A stationary kernel :math:`k(x, x') = k(x - x')`. + weight_generator: A callable used to generate weight vectors :math:`w`. + num_inputs: The number of input features. + num_outputs: The number of Fourier features. + """ + if num_outputs % 2: + raise UnsupportedError( + f"Expected an even number of output features, but received {num_outputs=}." + ) + + input_transform = InverseLengthscaleTransform(kernel) + if kernel.active_dims is not None: + num_inputs = len(kernel.active_dims) + input_transform = ChainedTransform( + input_transform, FeatureSelector(indices=kernel.active_dims) + ) + + weight = weight_generator( + Size([kernel.batch_shape.numel() * num_outputs // 2, num_inputs]) + ).reshape(*kernel.batch_shape, num_outputs // 2, num_inputs) + + output_transform = SineCosineTransform( + torch.tensor((2 / num_outputs) ** 0.5, device=kernel.device, dtype=kernel.dtype) + ) + return KernelFeatureMap( + kernel=kernel, + weight=weight, + input_transform=input_transform, + output_transform=output_transform, + ) + + +@GenKernelFeatures.register(kernels.RBFKernel) +def _gen_kernel_features_rbf( + kernel: kernels.RBFKernel, + *, + num_inputs: int, + num_outputs: int, +) -> KernelFeatureMap: + def _weight_generator(shape: Size) -> Tensor: + try: + n, d = shape + except ValueError: + raise UnsupportedError( + f"Expected `shape` to be 2-dimensional, but {len(shape)=}." + ) + + return draw_sobol_normal_samples( + n=n, + d=d, + device=kernel.lengthscale.device, + dtype=kernel.lengthscale.dtype, + ) + + return _gen_fourier_features( + kernel=kernel, + weight_generator=_weight_generator, + num_inputs=num_inputs, + num_outputs=num_outputs, + ) + + +@GenKernelFeatures.register(kernels.MaternKernel) +def _gen_kernel_features_matern( + kernel: kernels.MaternKernel, + *, + num_inputs: int, + num_outputs: int, +) -> KernelFeatureMap: + def _weight_generator(shape: Size) -> Tensor: + try: + n, d = shape + except ValueError: + raise UnsupportedError( + f"Expected `shape` to be 2-dimensional, but {len(shape)=}." + ) + + dtype = kernel.lengthscale.dtype + device = kernel.lengthscale.device + nu = torch.tensor(kernel.nu, device=device, dtype=dtype) + normals = draw_sobol_normal_samples(n=n, d=d, device=device, dtype=dtype) + return Gamma(nu, nu).rsample((n, 1)).rsqrt() * normals + + return _gen_fourier_features( + kernel=kernel, + weight_generator=_weight_generator, + num_inputs=num_inputs, + num_outputs=num_outputs, + ) + + +@GenKernelFeatures.register(kernels.ScaleKernel) +def _gen_kernel_features_scale( + kernel: kernels.ScaleKernel, + *, + num_inputs: int, + num_outputs: int, +) -> KernelFeatureMap: + active_dims = kernel.active_dims + feature_map = gen_kernel_features( + kernel.base_kernel, + num_inputs=num_inputs if active_dims is None else len(active_dims), + num_outputs=num_outputs, + ) + + if active_dims is not None and active_dims is not kernel.base_kernel.active_dims: + feature_map.input_transform = ChainedTransform( + feature_map.input_transform, FeatureSelector(indices=active_dims) + ) + + feature_map.output_transform = ChainedTransform( + OutputscaleTransform(kernel), feature_map.output_transform + ) + return feature_map diff --git a/botorch/sampling/pathwise/features/maps.py b/botorch/sampling/pathwise/features/maps.py new file mode 100644 index 0000000000..2600ba255d --- /dev/null +++ b/botorch/sampling/pathwise/features/maps.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import Optional, Union + +import torch +from botorch.sampling.pathwise.utils import ( + TInputTransform, + TOutputTransform, + TransformedModuleMixin, +) +from gpytorch.kernels import Kernel +from linear_operator.operators import LinearOperator +from torch import Size, Tensor +from torch.nn import Module + + +class FeatureMap(TransformedModuleMixin, Module): + num_outputs: int + batch_shape: Size + input_transform: Optional[TInputTransform] + output_transform: Optional[TOutputTransform] + + +class KernelEvaluationMap(FeatureMap): + r"""A feature map defined by centering a kernel at a set of points.""" + + def __init__( + self, + kernel: Kernel, + points: Tensor, + input_transform: Optional[TInputTransform] = None, + output_transform: Optional[TOutputTransform] = None, + ) -> None: + r"""Initializes a KernelEvaluationMap instance: + + .. code-block:: text + + feature_map(x) = output_transform(kernel(input_transform(x), points)). + + Args: + kernel: The kernel :math:`k` used to define the feature map. + points: A tensor passed as the kernel's second argument. + input_transform: An optional input transform for the module. + output_transform: An optional output transform for the module. + """ + try: + torch.broadcast_shapes(points.shape[:-2], kernel.batch_shape) + except RuntimeError: + raise RuntimeError( + f"Shape mismatch: {points.shape=}, but {kernel.batch_shape=}." + ) + + super().__init__() + self.kernel = kernel + self.points = points + self.input_transform = input_transform + self.output_transform = output_transform + + def forward(self, x: Tensor) -> Union[Tensor, LinearOperator]: + return self.kernel(x, self.points) + + @property + def num_outputs(self) -> int: + if self.output_transform is None: + return self.points.shape[-1] + + canary = torch.empty( + 1, self.points.shape[-1], device=self.points.device, dtype=self.points.dtype + ) + return self.output_transform(canary).shape[-1] + + @property + def batch_shape(self) -> Size: + return self.kernel.batch_shape + + +class KernelFeatureMap(FeatureMap): + r"""Representation of a kernel :math:`k: \mathcal{X}^2 \to \mathbb{R}` as an + n-dimensional feature map :math:`\phi: \mathcal{X} \to \mathbb{R}^n` satisfying: + :math:`k(x, x') ≈ \phi(x)^\top \phi(x')`. + """ + + def __init__( + self, + kernel: Kernel, + weight: Tensor, + bias: Optional[Tensor] = None, + input_transform: Optional[TInputTransform] = None, + output_transform: Optional[TOutputTransform] = None, + ) -> None: + r"""Initializes a KernelFeatureMap instance: + + .. code-block:: text + + feature_map(x) = output_transform(input_transform(x)^{T} weight + bias). + + Args: + kernel: The kernel :math:`k` used to define the feature map. + weight: A tensor of weights used to linearly combine the module's inputs. + bias: A tensor of biases to be added to the linearly combined inputs. + input_transform: An optional input transform for the module. + output_transform: An optional output transform for the module. + """ + super().__init__() + self.kernel = kernel + self.weight = weight + self.bias = bias + self.input_transform = input_transform + self.output_transform = output_transform + + def forward(self, x: Tensor) -> Tensor: + out = x @ self.weight.transpose(-2, -1) + return out if self.bias is None else out + self.bias + + @property + def num_outputs(self) -> int: + if self.output_transform is None: + return self.weight.shape[-2] + + canary = torch.empty( + self.weight.shape[-2], device=self.weight.device, dtype=self.weight.dtype + ) + return self.output_transform(canary).shape[-1] + + @property + def batch_shape(self) -> Size: + return self.kernel.batch_shape diff --git a/botorch/sampling/pathwise/paths.py b/botorch/sampling/pathwise/paths.py new file mode 100644 index 0000000000..84a7917fa4 --- /dev/null +++ b/botorch/sampling/pathwise/paths.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Mapping, + Optional, + Tuple, + Union, +) + +from botorch.exceptions.errors import UnsupportedError +from botorch.sampling.pathwise.features import FeatureMap +from botorch.sampling.pathwise.utils import ( + TInputTransform, + TOutputTransform, + TransformedModuleMixin, +) +from torch import Tensor +from torch.nn import Module, ModuleDict, ModuleList, Parameter + + +class SamplePath(ABC, TransformedModuleMixin, Module): + r"""Abstract base class for Botorch sample paths.""" + + +class PathDict(SamplePath): + r"""A dictionary of SamplePaths.""" + + def __init__( + self, + paths: Optional[Mapping[str, SamplePath]] = None, + join: Optional[Callable[[List[Tensor]], Tensor]] = None, + input_transform: Optional[TInputTransform] = None, + output_transform: Optional[TOutputTransform] = None, + ) -> None: + r"""Initializes a PathDict instance. + + Args: + paths: An optional mapping of strings to sample paths. + join: An optional callable used to combine each path's outputs. + input_transform: An optional input transform for the module. + output_transform: An optional output transform for the module. + """ + if join is None and output_transform is not None: + raise UnsupportedError("Output transforms must be preceded by a join rule.") + + super().__init__() + self.join = join + self.input_transform = input_transform + self.output_transform = output_transform + self.paths = ( + paths + if isinstance(paths, ModuleDict) + else ModuleDict({} if paths is None else paths) + ) + + def forward(self, x: Tensor, **kwargs: Any) -> Union[Tensor, Dict[str, Tensor]]: + out = [path(x, **kwargs) for path in self.paths.values()] + return dict(zip(self.paths, out)) if self.join is None else self.join(out) + + def items(self) -> Iterable[Tuple[str, SamplePath]]: + return self.paths.items() + + def keys(self) -> Iterable[str]: + return self.paths.keys() + + def values(self) -> Iterable[SamplePath]: + return self.paths.values() + + def __len__(self) -> int: + return len(self.paths) + + def __iter__(self) -> Iterator[SamplePath]: + yield from self.paths + + def __delitem__(self, key: str) -> None: + del self.paths[key] + + def __getitem__(self, key: str) -> SamplePath: + return self.paths[key] + + def __setitem__(self, key: str, val: SamplePath) -> None: + self.paths[key] = val + + +class PathList(SamplePath): + r"""A list of SamplePaths.""" + + def __init__( + self, + paths: Optional[Iterable[SamplePath]] = None, + join: Optional[Callable[[List[Tensor]], Tensor]] = None, + input_transform: Optional[TInputTransform] = None, + output_transform: Optional[TOutputTransform] = None, + ) -> None: + r"""Initializes a PathList instance. + + Args: + paths: An optional iterable of sample paths. + join: An optional callable used to combine each path's outputs. + input_transform: An optional input transform for the module. + output_transform: An optional output transform for the module. + """ + + if join is None and output_transform is not None: + raise UnsupportedError("Output transforms must be preceded by a join rule.") + + super().__init__() + self.join = join + self.input_transform = input_transform + self.output_transform = output_transform + self.paths = ( + paths + if isinstance(paths, ModuleList) + else ModuleList({} if paths is None else paths) + ) + + def forward(self, x: Tensor, **kwargs: Any) -> Union[Tensor, List[Tensor]]: + out = [path(x, **kwargs) for path in self.paths] + return out if self.join is None else self.join(out) + + def __len__(self) -> int: + return len(self.paths) + + def __iter__(self) -> Iterator[SamplePath]: + yield from self.paths + + def __delitem__(self, key: int) -> None: + del self.paths[key] + + def __getitem__(self, key: int) -> SamplePath: + return self.paths[key] + + def __setitem__(self, key: int, val: SamplePath) -> None: + self.paths[key] = val + + +class GeneralizedLinearPath(SamplePath): + r"""A sample path in the form of a generalized linear model.""" + + def __init__( + self, + feature_map: FeatureMap, + weight: Union[Parameter, Tensor], + bias_module: Optional[Module] = None, + input_transform: Optional[TInputTransform] = None, + output_transform: Optional[TOutputTransform] = None, + ): + r"""Initializes a GeneralizedLinearPath instance. + + .. code-block:: text + + path(x) = output_transform(bias_module(z) + feature_map(z)^T weight), + where z = input_transform(x). + + Args: + feature_map: A map used to featurize the module's inputs. + weight: A tensor of weights used to combine input features. + bias_module: An optional module used to define additive offsets. + input_transform: An optional input transform for the module. + output_transform: An optional output transform for the module. + """ + super().__init__() + self.feature_map = feature_map + self.weight = weight + self.bias_module = bias_module + self.input_transform = input_transform + self.output_transform = output_transform + + def forward(self, x: Tensor, **kwargs) -> Tensor: + feat = self.feature_map(x, **kwargs) + out = (feat @ self.weight.unsqueeze(-1)).squeeze(-1) + return out if self.bias_module is None else out + self.bias_module(x) diff --git a/botorch/sampling/pathwise/posterior_samplers.py b/botorch/sampling/pathwise/posterior_samplers.py new file mode 100644 index 0000000000..04223d63af --- /dev/null +++ b/botorch/sampling/pathwise/posterior_samplers.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +.. [wilson2020sampling] + J. Wilson, V. Borovitskiy, A. Terenin, P. Mostowsky, and M. Deisenroth. Efficiently + sampling functions from Gaussian process posteriors. International Conference on + Machine Learning (2020). + +.. [wilson2021pathwise] + J. Wilson, V. Borovitskiy, A. Terenin, P. Mostowsky, and M. Deisenroth. Pathwise + Conditioning of Gaussian Processes. Journal of Machine Learning Research (2021). +""" + +from __future__ import annotations + +from typing import Any, Optional, Union + +from botorch.models.approximate_gp import ApproximateGPyTorchModel +from botorch.models.model_list_gp_regression import ModelListGP +from botorch.sampling.pathwise.paths import PathDict, PathList, SamplePath +from botorch.sampling.pathwise.prior_samplers import ( + draw_kernel_feature_paths, + TPathwisePriorSampler, +) +from botorch.sampling.pathwise.update_strategies import gaussian_update, TPathwiseUpdate +from botorch.sampling.pathwise.utils import ( + get_output_transform, + get_train_inputs, + get_train_targets, + TInputTransform, + TOutputTransform, +) +from botorch.utils.context_managers import delattr_ctx +from botorch.utils.dispatcher import Dispatcher +from gpytorch.models import ApproximateGP, ExactGP, GP +from torch import Size + +DrawMatheronPaths = Dispatcher("draw_matheron_paths") + + +class MatheronPath(PathDict): + r"""Represents function draws from a GP posterior via Matheron's rule: + + .. code-block:: text + + "Prior path" + v + (f | y)(·) = f(·) + Cov(f(·), y) Cov(y, y)^{-1} (y - f(X) - ε), + \_______________________________________/ + v + "Update path" + + where `=` denotes equality in distribution, :math:`f \sim GP(0, k)`, + :math:`y \sim N(f(X), \Sigma)`, and :math:`\epsilon \sim N(0, \Sigma)`. + For more information, see [wilson2020sampling]_ and [wilson2021pathwise]_. + """ + + def __init__( + self, + prior_paths: SamplePath, + update_paths: SamplePath, + input_transform: Optional[TInputTransform] = None, + output_transform: Optional[TOutputTransform] = None, + ) -> None: + r"""Initializes a MatheronPath instance. + + Args: + prior_paths: Sample paths used to represent the prior. + update_paths: Sample paths used to represent the data. + input_transform: An optional input transform for the module. + output_transform: An optional output transform for the module. + """ + + super().__init__( + join=sum, + paths={"prior_paths": prior_paths, "update_paths": update_paths}, + input_transform=input_transform, + output_transform=output_transform, + ) + + +def draw_matheron_paths( + model: GP, + sample_shape: Size, + prior_sampler: TPathwisePriorSampler = draw_kernel_feature_paths, + update_strategy: TPathwiseUpdate = gaussian_update, + **kwargs: Any, +) -> MatheronPath: + r"""Generates function draws from (an approximate) Gaussian process prior. + + When evaluted, sample paths produced by this method return Tensors with dimensions + `sample_dims x batch_dims x [joint_dim]`, where `joint_dim` denotes the penultimate + dimension of the input tensor. For multioutput models, outputs are returned as the + final batch dimension. + + Args: + model: Gaussian process whose posterior is to be sampled. + sample_shape: Sizes of sample dimensions. + prior_sample: A callable that takes a model and a sample shape and returns + a set of sample paths representing the prior. + update_strategy: A callable that takes a model and a tensor of prior process + values and returns a set of sample paths representing the data. + """ + + return DrawMatheronPaths( + model, + sample_shape=sample_shape, + prior_sampler=prior_sampler, + update_strategy=update_strategy, + **kwargs, + ) + + +@DrawMatheronPaths.register(ModelListGP) +def _draw_matheron_paths_ModelListGP(model: ModelListGP, **kwargs: Any): + return PathList([draw_matheron_paths(m, **kwargs) for m in model.models]) + + +@DrawMatheronPaths.register(ExactGP) +def _draw_matheron_paths_ExactGP( + model: ExactGP, + *, + sample_shape: Size, + prior_sampler: TPathwisePriorSampler, + update_strategy: TPathwiseUpdate, +) -> MatheronPath: + (train_X,) = get_train_inputs(model, transformed=True) + train_Y = get_train_targets(model, transformed=True) + with delattr_ctx(model, "outcome_transform"): + # Generate draws from the prior + prior_paths = prior_sampler(model=model, sample_shape=sample_shape) + sample_values = prior_paths.forward(train_X) + + # Compute pathwise updates + update_paths = update_strategy( + model=model, + sample_values=sample_values, + train_targets=train_Y, + ) + + return MatheronPath( + prior_paths=prior_paths, + update_paths=update_paths, + output_transform=get_output_transform(model), + ) + + +@DrawMatheronPaths.register((ApproximateGP, ApproximateGPyTorchModel)) +def _draw_matheron_paths_ApproximateGP( + model: Union[ApproximateGP, ApproximateGPyTorchModel], + *, + sample_shape: Size, + prior_sampler: TPathwisePriorSampler, + update_strategy: TPathwiseUpdate, + **kwargs: Any, +) -> MatheronPath: + # Note: Inducing points are assumed to be pre-transformed + Z = ( + model.model.variational_strategy.inducing_points + if isinstance(model, ApproximateGPyTorchModel) + else model.variational_strategy.inducing_points + ) + with delattr_ctx(model, "outcome_transform"): + # Generate draws from the prior + prior_paths = prior_sampler(model=model, sample_shape=sample_shape) + sample_values = prior_paths.forward(Z) # `forward` bypasses transforms + + # Compute pathwise updates + update_paths = update_strategy(model=model, sample_values=sample_values) + + return MatheronPath( + prior_paths=prior_paths, + update_paths=update_paths, + output_transform=get_output_transform(model), + ) diff --git a/botorch/sampling/pathwise/prior_samplers.py b/botorch/sampling/pathwise/prior_samplers.py new file mode 100644 index 0000000000..e03d6bee5e --- /dev/null +++ b/botorch/sampling/pathwise/prior_samplers.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import Any, Callable, List, Optional + +from botorch.models.approximate_gp import ApproximateGPyTorchModel +from botorch.models.model_list_gp_regression import ModelListGP +from botorch.sampling.pathwise.features import gen_kernel_features +from botorch.sampling.pathwise.features.generators import TKernelFeatureMapGenerator +from botorch.sampling.pathwise.paths import GeneralizedLinearPath, PathList, SamplePath +from botorch.sampling.pathwise.utils import ( + get_input_transform, + get_output_transform, + get_train_inputs, + TInputTransform, + TOutputTransform, +) +from botorch.utils.dispatcher import Dispatcher +from botorch.utils.sampling import draw_sobol_normal_samples +from gpytorch.kernels import Kernel +from gpytorch.models import ApproximateGP, ExactGP, GP +from gpytorch.variational import _VariationalStrategy +from torch import Size, Tensor +from torch.nn import Module + +TPathwisePriorSampler = Callable[[GP, Size], SamplePath] +DrawKernelFeaturePaths = Dispatcher("draw_kernel_feature_paths") + + +def draw_kernel_feature_paths( + model: GP, sample_shape: Size, **kwargs: Any +) -> GeneralizedLinearPath: + r"""Draws functions from a Bayesian-linear-model-based approximation to a GP prior. + + When evaluted, sample paths produced by this method return Tensors with dimensions + `sample_dims x batch_dims x [joint_dim]`, where `joint_dim` denotes the penultimate + dimension of the input tensor. For multioutput models, outputs are returned as the + final batch dimension. + + Args: + model: The prior over functions. + sample_shape: The shape of the sample paths to be drawn. + """ + return DrawKernelFeaturePaths(model, sample_shape=sample_shape, **kwargs) + + +def _draw_kernel_feature_paths_fallback( + num_inputs: int, + mean_module: Optional[Module], + covar_module: Kernel, + sample_shape: Size, + num_features: int = 1024, + map_generator: TKernelFeatureMapGenerator = gen_kernel_features, + input_transform: Optional[TInputTransform] = None, + output_transform: Optional[TOutputTransform] = None, + weight_generator: Optional[Callable[[Size], Tensor]] = None, +) -> GeneralizedLinearPath: + + # Generate a kernel feature map + feature_map = map_generator( + kernel=covar_module, + num_inputs=num_inputs, + num_outputs=num_features, + ) + + # Sample random weights with which to combine kernel features + if weight_generator is None: + weight = draw_sobol_normal_samples( + n=sample_shape.numel() * covar_module.batch_shape.numel(), + d=feature_map.num_outputs, + device=covar_module.device, + dtype=covar_module.dtype, + ).reshape(sample_shape + covar_module.batch_shape + (feature_map.num_outputs,)) + else: + weight = weight_generator( + sample_shape + covar_module.batch_shape + (feature_map.num_outputs,) + ).to(device=covar_module.device, dtype=covar_module.dtype) + + # Return the sample paths + return GeneralizedLinearPath( + feature_map=feature_map, + weight=weight, + bias_module=mean_module, + input_transform=input_transform, + output_transform=output_transform, + ) + + +@DrawKernelFeaturePaths.register(ExactGP) +def _draw_kernel_feature_paths_ExactGP( + model: ExactGP, **kwargs: Any +) -> GeneralizedLinearPath: + (train_X,) = get_train_inputs(model, transformed=False) + return _draw_kernel_feature_paths_fallback( + num_inputs=train_X.shape[-1], + mean_module=model.mean_module, + covar_module=model.covar_module, + input_transform=get_input_transform(model), + output_transform=get_output_transform(model), + **kwargs, + ) + + +@DrawKernelFeaturePaths.register(ModelListGP) +def _draw_kernel_feature_paths_list( + model: ModelListGP, + join: Optional[Callable[[List[Tensor]], Tensor]] = None, + **kwargs: Any, +) -> PathList: + paths = [draw_kernel_feature_paths(m, **kwargs) for m in model.models] + return PathList(paths=paths, join=join) + + +@DrawKernelFeaturePaths.register(ApproximateGPyTorchModel) +def _draw_kernel_feature_paths_ApproximateGPyTorchModel( + model: ApproximateGPyTorchModel, **kwargs: Any +) -> GeneralizedLinearPath: + (train_X,) = get_train_inputs(model, transformed=False) + return DrawKernelFeaturePaths( + model.model, + num_inputs=train_X.shape[-1], + input_transform=get_input_transform(model), + output_transform=get_output_transform(model), + **kwargs, + ) + + +@DrawKernelFeaturePaths.register(ApproximateGP) +def _draw_kernel_feature_paths_ApproximateGP( + model: ApproximateGP, **kwargs: Any +) -> GeneralizedLinearPath: + return DrawKernelFeaturePaths(model, model.variational_strategy, **kwargs) + + +@DrawKernelFeaturePaths.register(ApproximateGP, _VariationalStrategy) +def _draw_kernel_feature_paths_ApproximateGP_fallback( + model: ApproximateGP, + _: _VariationalStrategy, + *, + num_inputs: int, + **kwargs: Any, +) -> GeneralizedLinearPath: + return _draw_kernel_feature_paths_fallback( + num_inputs=num_inputs, + mean_module=model.mean_module, + covar_module=model.covar_module, + **kwargs, + ) diff --git a/botorch/sampling/pathwise/update_strategies.py b/botorch/sampling/pathwise/update_strategies.py new file mode 100644 index 0000000000..d1e49985bf --- /dev/null +++ b/botorch/sampling/pathwise/update_strategies.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import Any, Callable, Optional, Union + +import torch +from botorch.models.approximate_gp import ApproximateGPyTorchModel +from botorch.models.transforms.input import InputTransform +from botorch.sampling.pathwise.features import KernelEvaluationMap +from botorch.sampling.pathwise.paths import GeneralizedLinearPath, SamplePath +from botorch.sampling.pathwise.utils import ( + get_input_transform, + get_train_inputs, + get_train_targets, + TInputTransform, +) +from botorch.utils.dispatcher import Dispatcher +from botorch.utils.types import DEFAULT, NoneType +from gpytorch.kernels.kernel import Kernel +from gpytorch.likelihoods import _GaussianLikelihoodBase, Likelihood +from gpytorch.models import ApproximateGP, ExactGP, GP +from gpytorch.variational import VariationalStrategy +from linear_operator.operators import ( + LinearOperator, + SumLinearOperator, + ZeroLinearOperator, +) +from torch import Tensor + +TPathwiseUpdate = Callable[[GP, Tensor], SamplePath] +GaussianUpdate = Dispatcher("gaussian_update") + + +def gaussian_update( + model: GP, + sample_values: Tensor, + likelihood: Optional[Likelihood] = DEFAULT, + **kwargs: Any, +) -> GeneralizedLinearPath: + r"""Computes a Gaussian pathwise update in exact arithmetic: + + .. code-block:: text + + (f | y)(·) = f(·) + Cov(f(·), y) Cov(y, y)^{-1} (y - f(X) - ε), + \_______________________________________/ + V + "Gaussian pathwise update" + + where `=` denotes equality in distribution, :math:`f \sim GP(0, k)`, + :math:`y \sim N(f(X), \Sigma)`, and :math:`\epsilon \sim N(0, \Sigma)`. + For more information, see [wilson2020sampling]_ and [wilson2021pathwise]_. + + Args: + model: A Gaussian process prior together with a likelihood. + sample_values: Assumed values for :math:`f(X)`. + likelihood: An optional likelihood used to help define the desired + update. Defaults to `model.likelihood` if it exists else None. + """ + if likelihood is DEFAULT: + likelihood = getattr(model, "likelihood", None) + + return GaussianUpdate(model, likelihood, sample_values=sample_values, **kwargs) + + +def _gaussian_update_exact( + kernel: Kernel, + points: Tensor, + target_values: Tensor, + sample_values: Tensor, + noise_covariance: Optional[Union[Tensor, LinearOperator]] = None, + scale_tril: Optional[Union[Tensor, LinearOperator]] = None, + input_transform: Optional[TInputTransform] = None, +) -> GeneralizedLinearPath: + # Prepare Cholesky factor of `Cov(y, y)` and noise sample values as needed + if isinstance(noise_covariance, (NoneType, ZeroLinearOperator)): + scale_tril = kernel(points).cholesky() if scale_tril is None else scale_tril + else: + noise_values = torch.randn_like(sample_values).unsqueeze(-1) + noise_values = noise_covariance.cholesky() @ noise_values + sample_values = sample_values + noise_values.squeeze(-1) + scale_tril = ( + SumLinearOperator(kernel(points), noise_covariance).cholesky() + if scale_tril is None + else scale_tril + ) + + # Solve for `Cov(y, y)^{-1}(Y - f(X) - ε)` + errors = target_values - sample_values + weight = torch.cholesky_solve(errors.unsqueeze(-1), scale_tril.to_dense()) + + # Define update feature map and paths + feature_map = KernelEvaluationMap( + kernel=kernel, + points=points, + input_transform=input_transform, + ) + return GeneralizedLinearPath(feature_map=feature_map, weight=weight.squeeze(-1)) + + +@GaussianUpdate.register(ExactGP, _GaussianLikelihoodBase) +def _gaussian_update_ExactGP( + model: ExactGP, + likelihood: _GaussianLikelihoodBase, + *, + sample_values: Tensor, + target_values: Optional[Tensor] = None, + points: Optional[Tensor] = None, + noise_covariance: Optional[Union[Tensor, LinearOperator]] = None, + scale_tril: Optional[Union[Tensor, LinearOperator]] = None, + **ignore: Any, +) -> GeneralizedLinearPath: + if points is None: + (points,) = get_train_inputs(model, transformed=True) + + if target_values is None: + target_values = get_train_targets(model, transformed=True) + + if noise_covariance is None: + noise_covariance = likelihood.noise_covar(shape=points.shape[:-1]) + + return _gaussian_update_exact( + kernel=model.covar_module, + points=points, + target_values=target_values, + sample_values=sample_values, + noise_covariance=noise_covariance, + scale_tril=scale_tril, + input_transform=get_input_transform(model), + ) + + +@GaussianUpdate.register(ApproximateGPyTorchModel, (Likelihood, NoneType)) +def _gaussian_update_ApproximateGPyTorchModel( + model: ApproximateGPyTorchModel, + likelihood: Union[Likelihood, NoneType], + **kwargs: Any, +) -> GeneralizedLinearPath: + return GaussianUpdate( + model.model, likelihood, input_transform=get_input_transform(model), **kwargs + ) + + +@GaussianUpdate.register(ApproximateGP, (Likelihood, NoneType)) +def _gaussian_update_ApproximateGP( + model: ApproximateGP, likelihood: Union[Likelihood, NoneType], **kwargs: Any +) -> GeneralizedLinearPath: + return GaussianUpdate(model, model.variational_strategy, **kwargs) + + +@GaussianUpdate.register(ApproximateGP, VariationalStrategy) +def _gaussian_update_ApproximateGP_VariationalStrategy( + model: ApproximateGP, + _: VariationalStrategy, + *, + sample_values: Tensor, + target_values: Optional[Tensor] = None, + noise_covariance: Optional[Union[Tensor, LinearOperator]] = None, + input_transform: Optional[InputTransform] = None, + **ignore: Any, +) -> GeneralizedLinearPath: + # TODO: Account for jitter added by `psd_safe_cholesky` + if not isinstance(noise_covariance, (NoneType, ZeroLinearOperator)): + raise NotImplementedError( + f"`noise_covariance` argument not yet supported for {type(model)}." + ) + + # Inducing points `Z` are assumed to live in transformed space + batch_shape = model.covar_module.batch_shape + v = model.variational_strategy + Z = v.inducing_points + L = v._cholesky_factor(v(Z, prior=True).lazy_covariance_matrix).to( + dtype=sample_values.dtype + ) + + # Generate whitened inducing variables `u`, then location-scale transform + if target_values is None: + u = v.variational_distribution.rsample( + sample_values.shape[: sample_values.ndim - len(batch_shape) - 1], + ) + target_values = model.mean_module(Z) + (u @ L.transpose(-1, -2)) + + return _gaussian_update_exact( + kernel=model.covar_module, + points=Z, + target_values=target_values, + sample_values=sample_values, + scale_tril=L, + input_transform=input_transform, + ) diff --git a/botorch/sampling/pathwise/utils.py b/botorch/sampling/pathwise/utils.py new file mode 100644 index 0000000000..045659dc29 --- /dev/null +++ b/botorch/sampling/pathwise/utils.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Callable, Iterable, List, Optional, overload, Tuple, Union + +import torch +from botorch.models.approximate_gp import SingleTaskVariationalGP +from botorch.models.gpytorch import GPyTorchModel +from botorch.models.model import Model, ModelList +from botorch.models.transforms.input import InputTransform +from botorch.models.transforms.outcome import OutcomeTransform +from botorch.utils.dispatcher import Dispatcher +from gpytorch.kernels import ScaleKernel +from gpytorch.kernels.kernel import Kernel +from torch import LongTensor, Tensor +from torch.nn import Module, ModuleList + +TInputTransform = Union[InputTransform, Callable[[Tensor], Tensor]] +TOutputTransform = Union[OutcomeTransform, Callable[[Tensor], Tensor]] +GetTrainInputs = Dispatcher("get_train_inputs") +GetTrainTargets = Dispatcher("get_train_targets") + + +class TransformedModuleMixin: + r"""Mixin that wraps a module's __call__ method with optional transforms.""" + input_transform: Optional[TInputTransform] + output_transform: Optional[TOutputTransform] + + def __call__(self, values: Tensor, *args: Any, **kwargs: Any) -> Tensor: + input_transform = getattr(self, "input_transform", None) + if input_transform is not None: + values = ( + input_transform.forward(values) + if isinstance(input_transform, InputTransform) + else input_transform(values) + ) + + output = super().__call__(values, *args, **kwargs) + output_transform = getattr(self, "output_transform", None) + if output_transform is None: + return output + + return ( + output_transform.untransform(output)[0] + if isinstance(output_transform, OutcomeTransform) + else output_transform(output) + ) + + +class TensorTransform(ABC, Module): + r"""Abstract base class for transforms that map tensor to tensor.""" + + @abstractmethod + def forward(self, values: Tensor, **kwargs: Any) -> Tensor: + pass # pragma: no cover + + +class ChainedTransform(TensorTransform): + r"""A composition of TensorTransforms.""" + + def __init__(self, *transforms: TensorTransform): + r"""Initializes a ChainedTransform instance. + + Args: + transforms: A set of transforms to be applied from right to left. + """ + super().__init__() + self.transforms = ModuleList(transforms) + + def forward(self, values: Tensor) -> Tensor: + for transform in reversed(self.transforms): + values = transform(values) + return values + + +class SineCosineTransform(TensorTransform): + r"""A transform that returns concatenated sine and cosine features.""" + + def __init__(self, scale: Optional[Tensor] = None): + r"""Initializes a SineCosineTransform instance. + + Args: + scale: An optional tensor used to rescale the module's outputs. + """ + super().__init__() + self.scale = scale + + def forward(self, values: Tensor) -> Tensor: + sincos = torch.concat([values.sin(), values.cos()], dim=-1) + return sincos if self.scale is None else self.scale * sincos + + +class InverseLengthscaleTransform(TensorTransform): + r"""A transform that divides its inputs by a kernels lengthscales.""" + + def __init__(self, kernel: Kernel): + r"""Initializes an InverseLengthscaleTransform instance. + + Args: + kernel: The kernel whose lengthscales are to be used. + """ + if not kernel.has_lengthscale: + raise RuntimeError(f"{type(kernel)} does not implement `lengthscale`.") + + super().__init__() + self.kernel = kernel + + def forward(self, values: Tensor) -> Tensor: + return self.kernel.lengthscale.reciprocal() * values + + +class OutputscaleTransform(TensorTransform): + r"""A transform that multiplies its inputs by the square root of a + kernel's outputscale.""" + + def __init__(self, kernel: ScaleKernel): + r"""Initializes an OutputscaleTransform instance. + + Args: + kernel: A ScaleKernel whose `outputscale` is to be used. + """ + super().__init__() + self.kernel = kernel + + def forward(self, values: Tensor) -> Tensor: + outputscale = ( + self.kernel.outputscale[..., None, None] + if self.kernel.batch_shape + else self.kernel.outputscale + ) + return outputscale.sqrt() * values + + +class FeatureSelector(TensorTransform): + r"""A transform that returns a subset of its input's features. + along a given tensor dimension.""" + + def __init__(self, indices: Iterable[int], dim: Union[int, LongTensor] = -1): + r"""Initializes a FeatureSelector instance. + + Args: + indices: A LongTensor of feature indices. + dim: The dimensional along which to index features. + """ + super().__init__() + self.register_buffer("dim", dim if torch.is_tensor(dim) else torch.tensor(dim)) + self.register_buffer( + "indices", indices if torch.is_tensor(indices) else torch.tensor(indices) + ) + + def forward(self, values: Tensor) -> Tensor: + return values.index_select(dim=self.dim, index=self.indices) + + +class OutcomeUntransformer(TensorTransform): + r"""Module acting as a bridge for `OutcomeTransform.untransform`.""" + + def __init__( + self, + transform: OutcomeTransform, + num_outputs: Union[int, LongTensor], + ): + r"""Initializes an OutcomeUntransformer instance. + + Args: + transform: The wrapped OutcomeTransform instance. + num_outputs: The number of outcome features that the + OutcomeTransform transforms. + """ + super().__init__() + self.transform = transform + self.register_buffer( + "num_outputs", + num_outputs if torch.is_tensor(num_outputs) else torch.tensor(num_outputs), + ) + + def forward(self, values: Tensor) -> Tensor: + # OutcomeTransforms expect an explicit output dimension in the final position. + if self.num_outputs == 1: # BoTorch has suppressed the output dimension + output_values, _ = self.transform.untransform(values.unsqueeze(-1)) + return output_values.squeeze(-1) + + # BoTorch has moved the output dimension inside as the final batch dimension. + output_values, _ = self.transform.untransform(values.transpose(-2, -1)) + return output_values.transpose(-2, -1) + + +def get_input_transform(model: GPyTorchModel) -> Optional[InputTransform]: + r"""Returns a model's input_transform or None.""" + return getattr(model, "input_transform", None) + + +def get_output_transform(model: GPyTorchModel) -> Optional[OutcomeUntransformer]: + r"""Returns a wrapped version of a model's outcome_transform or None.""" + transform = getattr(model, "outcome_transform", None) + if transform is None: + return None + + return OutcomeUntransformer(transform=transform, num_outputs=model.num_outputs) + + +@overload +def get_train_inputs(model: Model, transformed: bool = False) -> Tuple[Tensor, ...]: + pass # pragma: no cover + + +@overload +def get_train_inputs(model: ModelList, transformed: bool = False) -> List[...]: + pass # pragma: no cover + + +def get_train_inputs(model: Model, transformed: bool = False): + return GetTrainInputs(model, transformed=transformed) + + +@GetTrainInputs.register(Model) +def _get_train_inputs_Model(model: Model, transformed: bool = False) -> Tuple[Tensor]: + if not transformed: + original_train_input = getattr(model, "_original_train_inputs", None) + if torch.is_tensor(original_train_input): + return (original_train_input,) + + (X,) = model.train_inputs + transform = get_input_transform(model) + if transform is None: + return (X,) + + if model.training: + return (transform.forward(X) if transformed else X,) + return (X if transformed else transform.untransform(X),) + + +@GetTrainInputs.register(SingleTaskVariationalGP) +def _get_train_inputs_SingleTaskVariationalGP( + model: SingleTaskVariationalGP, transformed: bool = False +) -> Tuple[Tensor]: + (X,) = model.model.train_inputs + if model.training != transformed: + return (X,) + + transform = get_input_transform(model) + if transform is None: + return (X,) + + return (transform.forward(X) if model.training else transform.untransform(X),) + + +@GetTrainInputs.register(ModelList) +def _get_train_inputs_ModelList( + model: ModelList, transformed: bool = False +) -> List[...]: + return [get_train_inputs(m, transformed=transformed) for m in model.models] + + +@overload +def get_train_targets(model: Model, transformed: bool = False) -> Tensor: + pass # pragma: no cover + + +@overload +def get_train_targets(model: ModelList, transformed: bool = False) -> List[...]: + pass # pragma: no cover + + +def get_train_targets(model: Model, transformed: bool = False): + return GetTrainTargets(model, transformed=transformed) + + +@GetTrainTargets.register(Model) +def _get_train_targets_Model(model: Model, transformed: bool = False) -> Tensor: + Y = model.train_targets + + # Note: Avoid using `get_output_transform` here since it creates a Module + transform = getattr(model, "outcome_transform", None) + if transformed or transform is None: + return Y + + if model.num_outputs == 1: + return transform.untransform(Y.unsqueeze(-1))[0].squeeze(-1) + return transform.untransform(Y.transpose(-2, -1))[0].transpose(-2, -1) + + +@GetTrainTargets.register(SingleTaskVariationalGP) +def _get_train_targets_SingleTaskVariationalGP( + model: Model, transformed: bool = False +) -> Tensor: + Y = model.model.train_targets + transform = getattr(model, "outcome_transform", None) + if transformed or transform is None: + return Y + + if model.num_outputs == 1: + return transform.untransform(Y.unsqueeze(-1))[0].squeeze(-1) + + # SingleTaskVariationalGP.__init__ doesn't bring the multitoutpout dimension inside + return transform.untransform(Y)[0] + + +@GetTrainTargets.register(ModelList) +def _get_train_targets_ModelList( + model: ModelList, transformed: bool = False +) -> List[...]: + return [get_train_targets(m, transformed=transformed) for m in model.models] diff --git a/botorch/utils/context_managers.py b/botorch/utils/context_managers.py index 6257239793..e04873930f 100644 --- a/botorch/utils/context_managers.py +++ b/botorch/utils/context_managers.py @@ -24,7 +24,7 @@ class TensorCheckpoint(NamedTuple): @contextmanager -def del_attribute_ctx( +def delattr_ctx( instance: object, *attrs: str, enforce_hasattr: bool = False ) -> Generator[None, None, None]: r"""Contextmanager for temporarily deleting attributes.""" diff --git a/botorch/utils/types.py b/botorch/utils/types.py new file mode 100644 index 0000000000..6f4f0ffe21 --- /dev/null +++ b/botorch/utils/types.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + + +NoneType = type(None) # stop gap for the return of NoneType in 3.10 + + +class _DefaultType(type): + r""" + Private class whose sole instance `DEFAULT` is as a special indicator + representing that a default value should be assigned to an argument. + Typically used in cases where `None` is an allowed argument. + """ + + +DEFAULT = _DefaultType("DEFAULT", (), {}) diff --git a/sphinx/source/sampling.rst b/sphinx/source/sampling.rst index d58685f8e2..0113a50d43 100644 --- a/sphinx/source/sampling.rst +++ b/sphinx/source/sampling.rst @@ -46,3 +46,42 @@ Stochastic Samplers ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.sampling.stochastic_samplers :members: + + +Pathwise Sampling +------------------------------------------- + +Feature Maps +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.sampling.pathwise.features.maps + :members: + +Feature Map Generators +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.sampling.pathwise.features.generators + :members: + +Sample Paths +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.sampling.pathwise.paths + :members: + +Pathwise Prior Samplers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.sampling.pathwise.prior_samplers + :members: + +Pathwise Posterior Samplers +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.sampling.pathwise.posterior_samplers + :members: + +Pathwise Update Strategies +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.sampling.pathwise.update_strategies + :members: + +Utilities +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.sampling.pathwise.utils + :members: diff --git a/sphinx/source/utils.rst b/sphinx/source/utils.rst index b5ed6d3bfb..b49cfc2e62 100644 --- a/sphinx/source/utils.rst +++ b/sphinx/source/utils.rst @@ -78,6 +78,11 @@ Feasible Volume .. automodule:: botorch.utils.feasible_volume :members: +Types and Type Hints +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.utils.types + :members: + Constants ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.utils.constants diff --git a/test/posteriors/test_higher_order.py b/test/posteriors/test_higher_order.py index a2c285b0ee..a11dd07c2e 100644 --- a/test/posteriors/test_higher_order.py +++ b/test/posteriors/test_higher_order.py @@ -9,7 +9,7 @@ from botorch.exceptions.errors import BotorchTensorDimensionError from botorch.models.higher_order_gp import HigherOrderGP from botorch.posteriors.higher_order import HigherOrderGPPosterior -from botorch.sampling import IIDNormalSampler +from botorch.sampling.normal import IIDNormalSampler from botorch.utils.testing import BotorchTestCase diff --git a/test/posteriors/test_multitask.py b/test/posteriors/test_multitask.py index fbce444529..a5cd15b63e 100644 --- a/test/posteriors/test_multitask.py +++ b/test/posteriors/test_multitask.py @@ -9,7 +9,7 @@ from botorch.exceptions.errors import BotorchTensorDimensionError from botorch.models.multitask import KroneckerMultiTaskGP from botorch.posteriors.multitask import MultitaskGPPosterior -from botorch.sampling import IIDNormalSampler +from botorch.sampling.normal import IIDNormalSampler from botorch.utils.testing import BotorchTestCase diff --git a/test/sampling/pathwise/__init__.py b/test/sampling/pathwise/__init__.py new file mode 100644 index 0000000000..4b87eb9e4d --- /dev/null +++ b/test/sampling/pathwise/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/test/sampling/pathwise/features/__init__.py b/test/sampling/pathwise/features/__init__.py new file mode 100644 index 0000000000..4b87eb9e4d --- /dev/null +++ b/test/sampling/pathwise/features/__init__.py @@ -0,0 +1,5 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/test/sampling/pathwise/features/test_generators.py b/test/sampling/pathwise/features/test_generators.py new file mode 100644 index 0000000000..a55a1a64eb --- /dev/null +++ b/test/sampling/pathwise/features/test_generators.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from math import ceil +from unittest.mock import patch + +import torch +from botorch.exceptions.errors import UnsupportedError +from botorch.sampling.pathwise.features import generators +from botorch.sampling.pathwise.features.generators import gen_kernel_features +from botorch.sampling.pathwise.features.maps import FeatureMap +from botorch.utils.testing import BotorchTestCase +from gpytorch.kernels import MaternKernel, RBFKernel, ScaleKernel +from gpytorch.kernels.kernel import Kernel +from torch import Size, Tensor + + +class TestFeatureGenerators(BotorchTestCase): + def setUp(self, seed: int = 0) -> None: + super().setUp() + + self.kernels = [] + self.num_inputs = d = 2 + self.num_features = 4096 + for kernel in ( + MaternKernel(nu=0.5, batch_shape=Size([])), + MaternKernel(nu=1.5, ard_num_dims=1, active_dims=[0]), + ScaleKernel(MaternKernel(nu=2.5, ard_num_dims=d, batch_shape=Size([2]))), + ScaleKernel( + RBFKernel(ard_num_dims=1, batch_shape=Size([2, 2])), active_dims=[1] + ), + ): + kernel.to( + dtype=torch.float32 if (seed % 2) else torch.float64, device=self.device + ) + with torch.random.fork_rng(): + torch.manual_seed(seed) + kern = kernel.base_kernel if isinstance(kernel, ScaleKernel) else kernel + kern.lengthscale = 0.1 + 0.2 * torch.rand_like(kern.lengthscale) + seed += 1 + + self.kernels.append(kernel) + + def test_gen_kernel_features(self): + for seed, kernel in enumerate(self.kernels): + with torch.random.fork_rng(): + torch.random.manual_seed(seed) + feature_map = gen_kernel_features( + kernel=kernel, + num_inputs=self.num_inputs, + num_outputs=self.num_features, + ) + + n = 4 + m = ceil(n * kernel.batch_shape.numel() ** -0.5) + for input_batch_shape in ((n**2,), (m, *kernel.batch_shape, m)): + X = torch.rand( + (*input_batch_shape, self.num_inputs), + device=kernel.device, + dtype=kernel.dtype, + ) + self._test_gen_kernel_features(kernel, feature_map, X) + + def _test_gen_kernel_features( + self, kernel: Kernel, feature_map: FeatureMap, X: Tensor, atol: float = 3.0 + ): + with self.subTest("test_initialization"): + self.assertEqual(feature_map.weight.dtype, kernel.dtype) + self.assertEqual(feature_map.weight.device, kernel.device) + self.assertEqual( + feature_map.weight.shape[-1], + self.num_inputs + if kernel.active_dims is None + else len(kernel.active_dims), + ) + + with self.subTest("test_covariance"): + features = feature_map(X) + test_shape = torch.broadcast_shapes( + (*X.shape[:-1], self.num_features), kernel.batch_shape + (1, 1) + ) + self.assertEqual(features.shape, test_shape) + K0 = features @ features.transpose(-2, -1) + K1 = kernel(X).to_dense() + self.assertTrue( + K0.allclose(K1, atol=atol * self.num_features**-0.5, rtol=0) + ) + + # Test passing the wrong dimensional shape to `weight_generator` + with self.assertRaisesRegex(UnsupportedError, "2-dim"), patch.object( + generators, + "_gen_fourier_features", + side_effect=lambda **kwargs: kwargs["weight_generator"](Size([])), + ): + gen_kernel_features( + kernel=kernel, + num_inputs=self.num_inputs, + num_outputs=self.num_features, + ) + + # Test requesting an odd number of features + with self.assertRaisesRegex(UnsupportedError, "Expected an even number"): + gen_kernel_features( + kernel=kernel, num_inputs=self.num_inputs, num_outputs=3 + ) diff --git a/test/sampling/pathwise/features/test_maps.py b/test/sampling/pathwise/features/test_maps.py new file mode 100644 index 0000000000..842d2164c9 --- /dev/null +++ b/test/sampling/pathwise/features/test_maps.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import torch +from botorch.sampling.pathwise.features import KernelEvaluationMap, KernelFeatureMap +from botorch.utils.testing import BotorchTestCase +from gpytorch.kernels import MaternKernel +from torch import Size + + +class TestFeatureMaps(BotorchTestCase): + def test_kernel_evaluation_map(self): + kernel = MaternKernel(nu=2.5, ard_num_dims=2, batch_shape=Size([2])) + kernel.to(device=self.device) + with torch.random.fork_rng(): + torch.manual_seed(0) + kernel.lengthscale = 0.1 + 0.3 * torch.rand_like(kernel.lengthscale) + + with self.assertRaisesRegex(RuntimeError, "Shape mismatch"): + KernelEvaluationMap(kernel=kernel, points=torch.rand(4, 3, 2)) + + for dtype in (torch.float32, torch.float64): + kernel.to(dtype=dtype) + X0, X1 = torch.rand(5, 2, dtype=dtype, device=self.device).split([2, 3]) + kernel_map = KernelEvaluationMap(kernel=kernel, points=X1) + self.assertEqual(kernel_map.batch_shape, kernel.batch_shape) + self.assertEqual(kernel_map.num_outputs, X1.shape[-1]) + self.assertTrue(kernel_map(X0).to_dense().equal(kernel(X0, X1).to_dense())) + + with patch.object( + kernel_map, "output_transform", new=lambda z: torch.concat([z, z], dim=-1) + ): + self.assertEqual(kernel_map.num_outputs, 2 * X1.shape[-1]) + + def test_kernel_feature_map(self): + d = 2 + m = 3 + weight = torch.rand(m, d, device=self.device) + bias = torch.rand(m, device=self.device) + kernel = MaternKernel(nu=2.5, batch_shape=Size([3])).to(self.device) + feature_map = KernelFeatureMap( + kernel=kernel, + weight=weight, + bias=bias, + input_transform=MagicMock(side_effect=lambda x: x), + output_transform=MagicMock(side_effect=lambda z: z.exp()), + ) + + X = torch.rand(2, d, device=self.device) + features = feature_map(X) + feature_map.input_transform.assert_called_once_with(X) + feature_map.output_transform.assert_called_once() + self.assertTrue((X @ weight.transpose(-2, -1) + bias).exp().equal(features)) + + # Test batch_shape and num_outputs + self.assertIs(feature_map.batch_shape, kernel.batch_shape) + self.assertEqual(feature_map.num_outputs, weight.shape[-2]) + with patch.object(feature_map, "output_transform", new=None): + self.assertEqual(feature_map.num_outputs, weight.shape[-2]) diff --git a/test/sampling/pathwise/helpers.py b/test/sampling/pathwise/helpers.py new file mode 100644 index 0000000000..6740365839 --- /dev/null +++ b/test/sampling/pathwise/helpers.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import Tuple + +from botorch.models.transforms.outcome import Standardize +from torch import Size, Tensor + + +def get_sample_moments(samples: Tensor, sample_shape: Size) -> Tuple[Tensor, Tensor]: + sample_dim = len(sample_shape) + samples = samples.view(-1, *samples.shape[sample_dim:]) + loc = samples.mean(dim=0) + residuals = (samples - loc).permute(*range(1, samples.ndim), 0) + return loc, (residuals @ residuals.transpose(-2, -1)) / sample_shape.numel() + + +def standardize_moments( + transform: Standardize, + loc: Tensor, + covariance_matrix: Tensor, +) -> Tuple[Tensor, Tensor]: + + m = transform.means.squeeze().unsqueeze(-1) + s = transform.stdvs.squeeze().reciprocal().unsqueeze(-1) + loc = s * (loc - m) + correlation_matrix = s.unsqueeze(-1) * covariance_matrix * s.unsqueeze(-2) + return loc, correlation_matrix diff --git a/test/sampling/pathwise/test_paths.py b/test/sampling/pathwise/test_paths.py new file mode 100644 index 0000000000..3b24430f53 --- /dev/null +++ b/test/sampling/pathwise/test_paths.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import torch +from botorch.exceptions.errors import UnsupportedError +from botorch.sampling.pathwise.paths import PathDict, PathList, SamplePath +from botorch.utils.testing import BotorchTestCase +from torch.nn import ModuleDict, ModuleList + + +class IdentityPath(SamplePath): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + +class TestGenericPaths(BotorchTestCase): + def test_path_dict(self): + with self.assertRaisesRegex(UnsupportedError, "must be preceded by a join"): + PathDict(output_transform="foo") + + A = IdentityPath() + B = IdentityPath() + + # Test __init__ + module_dict = ModuleDict({"0": A, "1": B}) + path_dict = PathDict(paths={"0": A, "1": B}) + self.assertTrue(path_dict.paths is not module_dict) + + path_dict = PathDict(paths=module_dict) + self.assertIs(path_dict.paths, module_dict) + + # Test __call__ + x = torch.rand(3, device=self.device) + output = path_dict(x) + self.assertIsInstance(output, dict) + self.assertTrue(x.equal(output.pop("0"))) + self.assertTrue(x.equal(output.pop("1"))) + self.assertTrue(not output) + + path_dict.join = torch.stack + output = path_dict(x) + self.assertIsInstance(output, torch.Tensor) + self.assertEqual(output.shape, (2,) + x.shape) + self.assertTrue(output.eq(x).all()) + + # Test `dict`` methods + self.assertEqual(len(path_dict), 2) + for key, val, (key_0, val_0), (key_1, val_1), key_2 in zip( + path_dict, + path_dict.values(), + path_dict.items(), + path_dict.paths.items(), + path_dict.keys(), + ): + self.assertEqual(1, len({key, key_0, key_1, key_2})) + self.assertEqual(1, len({val, val_0, val_1, path_dict[key]})) + + path_dict["1"] = A # test __setitem__ + self.assertIs(path_dict.paths["1"], A) + + del path_dict["1"] # test __delitem__ + self.assertEqual(("0",), tuple(path_dict)) + + def test_path_list(self): + with self.assertRaisesRegex(UnsupportedError, "must be preceded by a join"): + PathList(output_transform="foo") + + # Test __init__ + A = IdentityPath() + B = IdentityPath() + module_list = ModuleList((A, B)) + path_list = PathList(paths=list(module_list)) + self.assertTrue(path_list.paths is not module_list) + + path_list = PathList(paths=module_list) + self.assertIs(path_list.paths, module_list) + + # Test __call__ + x = torch.rand(3, device=self.device) + output = path_list(x) + self.assertIsInstance(output, list) + self.assertTrue(x.equal(output.pop())) + self.assertTrue(x.equal(output.pop())) + self.assertTrue(not output) + + path_list.join = torch.stack + output = path_list(x) + self.assertIsInstance(output, torch.Tensor) + self.assertEqual(output.shape, (2,) + x.shape) + self.assertTrue(output.eq(x).all()) + + # Test `list` methods + self.assertEqual(len(path_list), 2) + for key, (path, path_0) in enumerate(zip(path_list, path_list.paths)): + self.assertEqual(1, len({path, path_0, path_list[key]})) + + path_list[1] = A # test __setitem__ + self.assertIs(path_list.paths[1], A) + + del path_list[1] # test __delitem__ + self.assertEqual((A,), tuple(path_list)) diff --git a/test/sampling/pathwise/test_posterior_samplers.py b/test/sampling/pathwise/test_posterior_samplers.py new file mode 100644 index 0000000000..d6b3ca6fc2 --- /dev/null +++ b/test/sampling/pathwise/test_posterior_samplers.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from collections import defaultdict +from copy import deepcopy +from itertools import product + +import torch +from botorch.models import ( + FixedNoiseGP, + ModelListGP, + SingleTaskGP, + SingleTaskVariationalGP, +) +from botorch.models.transforms.input import Normalize +from botorch.models.transforms.outcome import Standardize +from botorch.sampling.pathwise import draw_matheron_paths, MatheronPath, PathList +from botorch.sampling.pathwise.utils import get_train_inputs +from botorch.utils.testing import BotorchTestCase +from gpytorch.kernels import MaternKernel, ScaleKernel +from torch import Size +from torch.nn.functional import pad + +from .helpers import get_sample_moments, standardize_moments + + +class TestPosteriorSamplers(BotorchTestCase): + def setUp(self) -> None: + super().setUp() + self.models = defaultdict(list) + + seed = 0 + for kernel in ( + ScaleKernel(MaternKernel(nu=2.5, ard_num_dims=2, batch_shape=Size([]))), + ): + with torch.random.fork_rng(): + torch.manual_seed(seed) + tkwargs = {"device": self.device, "dtype": torch.float64} + + base = kernel.base_kernel if isinstance(kernel, ScaleKernel) else kernel + base.lengthscale = 0.1 + 0.3 * torch.rand_like(base.lengthscale) + kernel.to(**tkwargs) + + uppers = 1 + 9 * torch.rand(base.lengthscale.shape[-1], **tkwargs) + bounds = pad(uppers.unsqueeze(0), (0, 0, 1, 0)) + + X = uppers * torch.rand(4, base.lengthscale.shape[-1], **tkwargs) + Y = 10 * kernel(X).cholesky() @ torch.randn(4, 1, **tkwargs) + if kernel.batch_shape: + Y = Y.squeeze(-1).transpose(0, 1) # n x m + + input_transform = Normalize(d=X.shape[-1], bounds=bounds) + outcome_transform = Standardize(m=Y.shape[-1]) + + # SingleTaskGP in eval mode + self.models[SingleTaskGP].append( + SingleTaskGP( + train_X=X, + train_Y=Y, + covar_module=deepcopy(kernel), + input_transform=deepcopy(input_transform), + outcome_transform=deepcopy(outcome_transform), + ) + .to(**tkwargs) + .eval() + ) + + # FixedNoiseGP in train mode + self.models[FixedNoiseGP].append( + FixedNoiseGP( + train_X=X, + train_Y=Y, + train_Yvar=0.01 * torch.rand_like(Y), + covar_module=kernel, + input_transform=input_transform, + outcome_transform=outcome_transform, + ).to(**tkwargs) + ) + + # SingleTaskVariationalGP in train mode + self.models[SingleTaskVariationalGP].append( + SingleTaskVariationalGP( + train_X=X, + train_Y=Y, + covar_module=kernel, + input_transform=input_transform, + outcome_transform=outcome_transform, + ).to(**tkwargs) + ) + + seed += 1 + + def test_draw_matheron_paths(self): + for seed, models in enumerate(self.models.values()): + for model, sample_shape in product(models, [Size([1024]), Size([32, 32])]): + with torch.random.fork_rng(): + torch.random.manual_seed(seed) + paths = draw_matheron_paths(model=model, sample_shape=sample_shape) + self.assertIsInstance(paths, MatheronPath) + self._test_draw_matheron_paths(model, paths, sample_shape) + + with self.subTest("test_model_list"): + model_list = ModelListGP( + self.models[SingleTaskGP][0], self.models[FixedNoiseGP][0] + ) + path_list = draw_matheron_paths(model_list, sample_shape=sample_shape) + (train_X,) = get_train_inputs(model_list.models[0], transformed=False) + X = torch.zeros( + 4, train_X.shape[-1], dtype=train_X.dtype, device=self.device + ) + sample_list = path_list(X) + self.assertIsInstance(path_list, PathList) + self.assertIsInstance(sample_list, list) + self.assertEqual(len(sample_list), len(path_list.paths)) + + def _test_draw_matheron_paths(self, model, paths, sample_shape, atol=3): + (train_X,) = get_train_inputs(model, transformed=False) + X = torch.rand(16, train_X.shape[-1], dtype=train_X.dtype, device=self.device) + + # Evaluate sample paths and compute sample statistics + samples = paths(X) + batch_shape = ( + model.model.covar_module.batch_shape + if isinstance(model, SingleTaskVariationalGP) + else model.covar_module.batch_shape + ) + self.assertEqual(samples.shape, sample_shape + batch_shape + X.shape[-2:-1]) + + sample_moments = get_sample_moments(samples, sample_shape) + if hasattr(model, "outcome_transform"): + # Do this instead of untransforming exact moments + sample_moments = standardize_moments( + model.outcome_transform, *sample_moments + ) + + if model.training: + model.eval() + mvn = model(model.transform_inputs(X)) + model.train() + else: + mvn = model(model.transform_inputs(X)) + exact_moments = (mvn.loc, mvn.covariance_matrix) + + # Compare moments + num_features = paths["prior_paths"].weight.shape[-1] + tol = atol * (num_features**-0.5 + sample_shape.numel() ** -0.5) + for exact, estimate in zip(exact_moments, sample_moments): + self.assertTrue(exact.allclose(estimate, atol=tol, rtol=0)) diff --git a/test/sampling/pathwise/test_prior_samplers.py b/test/sampling/pathwise/test_prior_samplers.py new file mode 100644 index 0000000000..80c37c0428 --- /dev/null +++ b/test/sampling/pathwise/test_prior_samplers.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from collections import defaultdict +from copy import deepcopy +from itertools import product +from unittest.mock import MagicMock + +import torch +from botorch.models import ( + FixedNoiseGP, + ModelListGP, + SingleTaskGP, + SingleTaskVariationalGP, +) +from botorch.models.transforms.input import Normalize +from botorch.models.transforms.outcome import Standardize +from botorch.sampling.pathwise import ( + draw_kernel_feature_paths, + GeneralizedLinearPath, + PathList, +) +from botorch.sampling.pathwise.utils import get_train_inputs +from botorch.utils.testing import BotorchTestCase +from gpytorch.kernels import MaternKernel, RBFKernel, ScaleKernel +from torch import Size +from torch.nn.functional import pad + +from .helpers import get_sample_moments, standardize_moments + + +class TestPriorSamplers(BotorchTestCase): + def setUp(self) -> None: + super().setUp() + self.models = defaultdict(list) + self.num_features = 1024 + + seed = 0 + for kernel in ( + MaternKernel(nu=2.5, ard_num_dims=2, batch_shape=Size([])), + ScaleKernel(RBFKernel(ard_num_dims=2, batch_shape=Size([2]))), + ): + with torch.random.fork_rng(): + torch.manual_seed(seed) + tkwargs = {"device": self.device, "dtype": torch.float64} + + base = kernel.base_kernel if isinstance(kernel, ScaleKernel) else kernel + base.lengthscale = 0.1 + 0.3 * torch.rand_like(base.lengthscale) + kernel.to(**tkwargs) + + uppers = 1 + 9 * torch.rand(base.lengthscale.shape[-1], **tkwargs) + bounds = pad(uppers.unsqueeze(0), (0, 0, 1, 0)) + + X = uppers * torch.rand(4, base.lengthscale.shape[-1], **tkwargs) + Y = 10 * kernel(X).cholesky() @ torch.randn(4, 1, **tkwargs) + if kernel.batch_shape: + Y = Y.squeeze(-1).transpose(0, 1) # n x m + + input_transform = Normalize(d=X.shape[-1], bounds=bounds) + outcome_transform = Standardize(m=Y.shape[-1]) + + # SingleTaskGP in eval mode + self.models[SingleTaskGP].append( + SingleTaskGP( + train_X=X, + train_Y=Y, + covar_module=deepcopy(kernel), + input_transform=deepcopy(input_transform), + outcome_transform=deepcopy(outcome_transform), + ) + .to(**tkwargs) + .eval() + ) + + # FixedNoiseGP in train mode + self.models[FixedNoiseGP].append( + FixedNoiseGP( + train_X=X, + train_Y=Y, + train_Yvar=0.01 * torch.rand_like(Y), + covar_module=kernel, + input_transform=input_transform, + outcome_transform=outcome_transform, + ).to(**tkwargs) + ) + + # SingleTaskVariationalGP in train mode + # When batched, uses a multitask format which break the tests below + if not kernel.batch_shape: + self.models[SingleTaskVariationalGP].append( + SingleTaskVariationalGP( + train_X=X, + train_Y=Y, + covar_module=kernel, + input_transform=input_transform, + outcome_transform=outcome_transform, + ).to(**tkwargs) + ) + + seed += 1 + + def test_draw_kernel_feature_paths(self): + for seed, models in enumerate(self.models.values()): + for model, sample_shape in product(models, [Size([1024]), Size([2, 512])]): + with torch.random.fork_rng(): + torch.random.manual_seed(seed) + paths = draw_kernel_feature_paths( + model=model, + sample_shape=sample_shape, + num_features=self.num_features, + ) + self.assertIsInstance(paths, GeneralizedLinearPath) + self._test_draw_kernel_feature_paths(model, paths, sample_shape) + + with self.subTest("test_model_list"): + model_list = ModelListGP( + self.models[SingleTaskGP][0], self.models[FixedNoiseGP][0] + ) + path_list = draw_kernel_feature_paths( + model=model_list, + sample_shape=sample_shape, + num_features=self.num_features, + ) + (train_X,) = get_train_inputs(model_list.models[0], transformed=False) + X = torch.zeros( + 4, train_X.shape[-1], dtype=train_X.dtype, device=self.device + ) + sample_list = path_list(X) + self.assertIsInstance(path_list, PathList) + self.assertIsInstance(sample_list, list) + self.assertEqual(len(sample_list), len(path_list.paths)) + + with self.subTest("test_initialization"): + model = self.models[SingleTaskGP][0] + sample_shape = torch.Size([16]) + weight_generator = MagicMock() + draw_kernel_feature_paths( + model=model, + sample_shape=sample_shape, + num_features=self.num_features, + weight_generator=weight_generator, + ) + weight_generator.assert_called_once_with( + sample_shape + model.covar_module.batch_shape + (self.num_features,) + ) + + def _test_draw_kernel_feature_paths(self, model, paths, sample_shape, atol=3): + (train_X,) = get_train_inputs(model, transformed=False) + X = torch.rand(16, train_X.shape[-1], dtype=train_X.dtype, device=self.device) + + # Evaluate sample paths + samples = paths(X) + batch_shape = ( + model.model.covar_module.batch_shape + if isinstance(model, SingleTaskVariationalGP) + else model.covar_module.batch_shape + ) + self.assertEqual(samples.shape, sample_shape + batch_shape + X.shape[-2:-1]) + + # Calculate sample statistics + sample_moments = get_sample_moments(samples, sample_shape) + if hasattr(model, "outcome_transform"): + # Do this instead of untransforming exact moments + sample_moments = standardize_moments( + model.outcome_transform, *sample_moments + ) + + # Compute prior distribution + prior = model.forward(X if model.training else model.input_transform(X)) + exact_moments = (prior.loc, prior.covariance_matrix) + + # Compare moments + tol = atol * (paths.weight.shape[-1] ** -0.5 + sample_shape.numel() ** -0.5) + for exact, estimate in zip(exact_moments, sample_moments): + self.assertTrue(exact.allclose(estimate, atol=tol, rtol=0)) diff --git a/test/sampling/pathwise/test_update_strategies.py b/test/sampling/pathwise/test_update_strategies.py new file mode 100644 index 0000000000..6f52015dc2 --- /dev/null +++ b/test/sampling/pathwise/test_update_strategies.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from collections import defaultdict +from copy import deepcopy +from itertools import chain +from unittest.mock import patch + +import torch +from botorch.models import FixedNoiseGP, SingleTaskGP, SingleTaskVariationalGP +from botorch.models.transforms.input import Normalize +from botorch.models.transforms.outcome import Standardize +from botorch.sampling.pathwise import ( + draw_kernel_feature_paths, + gaussian_update, + GeneralizedLinearPath, + KernelEvaluationMap, +) +from botorch.sampling.pathwise.utils import get_train_inputs, get_train_targets +from botorch.utils.context_managers import delattr_ctx +from botorch.utils.testing import BotorchTestCase +from gpytorch.kernels import MaternKernel, RBFKernel, ScaleKernel +from gpytorch.likelihoods import BernoulliLikelihood +from gpytorch.utils.cholesky import psd_safe_cholesky +from linear_operator.operators import ZeroLinearOperator +from torch import Size +from torch.nn.functional import pad + + +class TestPathwiseUpdates(BotorchTestCase): + def setUp(self) -> None: + super().setUp() + self.models = defaultdict(list) + + seed = 0 + for kernel in ( + RBFKernel(ard_num_dims=2), + ScaleKernel(MaternKernel(nu=2.5, ard_num_dims=2, batch_shape=Size([2]))), + ): + with torch.random.fork_rng(): + torch.manual_seed(seed) + tkwargs = {"device": self.device, "dtype": torch.float64} + + base = kernel.base_kernel if isinstance(kernel, ScaleKernel) else kernel + base.lengthscale = 0.1 + 0.3 * torch.rand_like(base.lengthscale) + kernel.to(**tkwargs) + + uppers = 1 + 9 * torch.rand(base.lengthscale.shape[-1], **tkwargs) + bounds = pad(uppers.unsqueeze(0), (0, 0, 1, 0)) + + X = uppers * torch.rand(4, base.lengthscale.shape[-1], **tkwargs) + Y = 10 * kernel(X).cholesky() @ torch.randn(4, 1, **tkwargs) + if kernel.batch_shape: + Y = Y.squeeze(-1).transpose(0, 1) # n x m + + input_transform = Normalize(d=X.shape[-1], bounds=bounds) + outcome_transform = Standardize(m=Y.shape[-1]) + + # SingleTaskGP in eval mode + self.models[SingleTaskGP].append( + SingleTaskGP( + train_X=X, + train_Y=Y, + covar_module=deepcopy(kernel), + input_transform=deepcopy(input_transform), + outcome_transform=deepcopy(outcome_transform), + ) + .to(**tkwargs) + .eval() + ) + + # FixedNoiseGP in train mode + self.models[FixedNoiseGP].append( + FixedNoiseGP( + train_X=X, + train_Y=Y, + train_Yvar=0.01 * torch.rand_like(Y), + covar_module=kernel, + input_transform=input_transform, + outcome_transform=outcome_transform, + ).to(**tkwargs) + ) + + # SingleTaskVariationalGP in train mode + # When batched, uses a multitask format which break the tests below + if not kernel.batch_shape: + self.models[SingleTaskVariationalGP].append( + SingleTaskVariationalGP( + train_X=X, + train_Y=Y, + covar_module=kernel, + input_transform=input_transform, + outcome_transform=outcome_transform, + ).to(**tkwargs) + ) + + seed += 1 + + def test_gaussian_updates(self): + for seed, model in enumerate(chain.from_iterable(self.models.values())): + with torch.random.fork_rng(): + torch.manual_seed(seed) + self._test_gaussian_updates(model) + + def _test_gaussian_updates(self, model): + sample_shape = torch.Size([3]) + + # Extract exact conditions and precompute covariances + if isinstance(model, SingleTaskVariationalGP): + Z = model.model.variational_strategy.inducing_points + X = ( + Z + if model.input_transform is None + else model.input_transform.untransform(Z) + ) + U = torch.randn(len(Z), device=Z.device, dtype=Z.dtype) + Kuu = Kmm = model.model.covar_module(Z) + noise_values = None + else: + (X,) = get_train_inputs(model, transformed=False) + (Z,) = get_train_inputs(model, transformed=True) + U = get_train_targets(model, transformed=True) + Kmm = model.forward(X if model.training else Z).lazy_covariance_matrix + Kuu = Kmm + model.likelihood.noise_covar(shape=Z.shape[:-1]) + noise_values = torch.randn( + *sample_shape, *U.shape, device=U.device, dtype=U.dtype + ) + + # Disable sampling of noise variables `e` used to obtain `y = f + e` + with delattr_ctx(model, "outcome_transform"), patch.object( + torch, + "randn_like", + return_value=noise_values, + ): + prior_paths = draw_kernel_feature_paths(model, sample_shape=sample_shape) + sample_values = prior_paths(X) + update_paths = gaussian_update( + model=model, + sample_values=sample_values, + target_values=U, + ) + + # Test initialization + self.assertIsInstance(update_paths, GeneralizedLinearPath) + self.assertIsInstance(update_paths.feature_map, KernelEvaluationMap) + self.assertTrue(update_paths.feature_map.points.equal(Z)) + self.assertIs( + update_paths.feature_map.input_transform, + getattr(model, "input_transform", None), + ) + + # Compare with manually computed update weights `Cov(y, y)^{-1} (y - f - e)` + Luu = psd_safe_cholesky(Kuu.to_dense()) + errors = U - sample_values + if noise_values is not None: + errors -= ( + model.likelihood.noise_covar(shape=Z.shape[:-1]).cholesky() + @ noise_values.unsqueeze(-1) + ).squeeze(-1) + weight = torch.cholesky_solve(errors.unsqueeze(-1), Luu).squeeze(-1) + self.assertTrue(weight.allclose(update_paths.weight)) + + # Compare with manually computed update values at test locations + Z2 = torch.rand(16, Z.shape[-1], device=self.device, dtype=Z.dtype) + X2 = ( + model.input_transform.untransform(Z2) + if hasattr(model, "input_transform") + else Z2 + ) + features = update_paths.feature_map(X2) + expected_updates = (features @ update_paths.weight.unsqueeze(-1)).squeeze(-1) + actual_updates = update_paths(X2) + self.assertTrue(actual_updates.allclose(expected_updates)) + + # Test passing `noise_covariance` + m = Z.shape[-2] + update_paths = gaussian_update( + model=model, + sample_values=sample_values, + target_values=U, + noise_covariance=ZeroLinearOperator(m, m, dtype=X.dtype), + ) + Lmm = psd_safe_cholesky(Kmm.to_dense()) + errors = U - sample_values + weight = torch.cholesky_solve(errors.unsqueeze(-1), Lmm).squeeze(-1) + self.assertTrue(weight.allclose(update_paths.weight)) + + if isinstance(model, SingleTaskVariationalGP): + # Test passing non-zero `noise_covariance`` + with patch.object(model, "likelihood", new=BernoulliLikelihood()): + with self.assertRaisesRegex(NotImplementedError, "not yet supported"): + gaussian_update( + model=model, + sample_values=sample_values, + noise_covariance="foo", + ) + else: + # Test exact models with non-Gaussian likelihoods + with patch.object(model, "likelihood", new=BernoulliLikelihood()): + with self.assertRaises(NotImplementedError): + gaussian_update(model=model, sample_values=sample_values) diff --git a/test/sampling/pathwise/test_utils.py b/test/sampling/pathwise/test_utils.py new file mode 100644 index 0000000000..b69bf298bb --- /dev/null +++ b/test/sampling/pathwise/test_utils.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from unittest.mock import patch + +import torch +from botorch.models import SingleTaskGP, SingleTaskVariationalGP +from botorch.models.model_list_gp_regression import ModelListGP +from botorch.models.transforms.input import Normalize +from botorch.models.transforms.outcome import Standardize +from botorch.sampling.pathwise.utils import ( + get_input_transform, + get_output_transform, + get_train_inputs, + get_train_targets, + InverseLengthscaleTransform, + OutcomeUntransformer, +) +from botorch.utils.context_managers import delattr_ctx +from botorch.utils.testing import BotorchTestCase +from gpytorch.kernels import MaternKernel, ScaleKernel + + +class TestTransforms(BotorchTestCase): + def test_inverse_lengthscale_transform(self): + tkwargs = {"device": self.device, "dtype": torch.float64} + kernel = MaternKernel(nu=2.5, ard_num_dims=3).to(**tkwargs) + with self.assertRaisesRegex(RuntimeError, "does not implement `lengthscale`"): + InverseLengthscaleTransform(ScaleKernel(kernel)) + + x = torch.rand(3, 3, **tkwargs) + transform = InverseLengthscaleTransform(kernel) + self.assertTrue(transform(x).equal(kernel.lengthscale.reciprocal() * x)) + + def test_outcome_untransformer(self): + for untransformer in ( + OutcomeUntransformer(transform=Standardize(m=1), num_outputs=1), + OutcomeUntransformer(transform=Standardize(m=2), num_outputs=2), + ): + with torch.random.fork_rng(): + torch.random.manual_seed(0) + y = torch.rand(untransformer.num_outputs, 4, device=self.device) + x = untransformer.transform(y.T)[0].T + self.assertTrue(y.allclose(untransformer(x))) + + +class TestGetters(BotorchTestCase): + def setUp(self): + super().setUp() + with torch.random.fork_rng(): + torch.random.manual_seed(0) + train_X = torch.rand(5, 2) + train_Y = torch.randn(5, 2) + + self.models = [] + for num_outputs in (1, 2): + self.models.append( + SingleTaskGP( + train_X=train_X, + train_Y=train_Y[:, :num_outputs], + input_transform=Normalize(d=2), + outcome_transform=Standardize(m=num_outputs), + ) + ) + + self.models.append( + SingleTaskVariationalGP( + train_X=train_X, + train_Y=train_Y[:, :num_outputs], + input_transform=Normalize(d=2), + outcome_transform=Standardize(m=num_outputs), + ) + ) + + def test_get_input_transform(self): + for model in self.models: + self.assertIs(get_input_transform(model), model.input_transform) + + def test_get_output_transform(self): + for model in self.models: + transform = get_output_transform(model) + self.assertIsInstance(transform, OutcomeUntransformer) + self.assertIs(transform.transform, model.outcome_transform) + + def test_get_train_inputs(self): + for model in self.models: + model.train() + X = ( + model.model.train_inputs[0] + if isinstance(model, SingleTaskVariationalGP) + else model.train_inputs[0] + ) + Z = model.input_transform(X) + train_inputs = get_train_inputs(model, transformed=False) + self.assertIsInstance(train_inputs, tuple) + self.assertEqual(len(train_inputs), 1) + + self.assertTrue(X.equal(get_train_inputs(model, transformed=False)[0])) + self.assertTrue(Z.equal(get_train_inputs(model, transformed=True)[0])) + + model.eval() + self.assertTrue(X.equal(get_train_inputs(model, transformed=False)[0])) + self.assertTrue(Z.equal(get_train_inputs(model, transformed=True)[0])) + with delattr_ctx(model, "input_transform"), patch.object( + model, "_original_train_inputs", new=None + ): + self.assertTrue(Z.equal(get_train_inputs(model, transformed=False)[0])) + self.assertTrue(Z.equal(get_train_inputs(model, transformed=True)[0])) + + with self.subTest("test_model_list"): + model_list = ModelListGP(*self.models) + input_list = get_train_inputs(model_list) + self.assertIsInstance(input_list, list) + self.assertEqual(len(input_list), len(self.models)) + for model, train_inputs in zip(model_list.models, input_list): + for a, b in zip(train_inputs, get_train_inputs(model)): + self.assertTrue(a.equal(b)) + + def test_get_train_targets(self): + for model in self.models: + model.train() + if isinstance(model, SingleTaskVariationalGP): + F = model.model.train_targets + Y = model.outcome_transform.untransform(F)[0].squeeze(dim=0) + else: + F = model.train_targets + Y = OutcomeUntransformer(model.outcome_transform, model.num_outputs)(F) + + self.assertTrue(F.equal(get_train_targets(model, transformed=True))) + self.assertTrue(Y.equal(get_train_targets(model, transformed=False))) + + model.eval() + self.assertTrue(F.equal(get_train_targets(model, transformed=True))) + self.assertTrue(Y.equal(get_train_targets(model, transformed=False))) + with delattr_ctx(model, "outcome_transform"): + self.assertTrue(F.equal(get_train_targets(model, transformed=True))) + self.assertTrue(F.equal(get_train_targets(model, transformed=False))) + + with self.subTest("test_model_list"): + model_list = ModelListGP(*self.models) + target_list = get_train_targets(model_list) + self.assertIsInstance(target_list, list) + self.assertEqual(len(target_list), len(self.models)) + for model, Y in zip(self.models, target_list): + self.assertTrue(Y.equal(get_train_targets(model))) diff --git a/test/utils/test_context_managers.py b/test/utils/test_context_managers.py index a9e3d141d3..16b442b56d 100644 --- a/test/utils/test_context_managers.py +++ b/test/utils/test_context_managers.py @@ -10,7 +10,7 @@ import torch from botorch.utils.context_managers import ( - del_attribute_ctx, + delattr_ctx, module_rollback_ctx, parameter_rollback_ctx, requires_grad_ctx, @@ -29,11 +29,11 @@ def setUp(self): param = Parameter(values.to(torch.float64), requires_grad=bool(i % 2)) module.register_parameter(name, param) - def test_del_attribute_ctx(self): + def test_delattr_ctx(self): # Test temporary removal of attributes a = self.module.a b = self.module.b - with del_attribute_ctx(self.module, "a", "b"): + with delattr_ctx(self.module, "a", "b"): self.assertIsNone(getattr(self.module, "a", None)) self.assertIsNone(getattr(self.module, "b", None)) self.assertTrue(self.module.c is not None) @@ -43,7 +43,7 @@ def test_del_attribute_ctx(self): self.assertTrue(self.module.b.equal(b)) with self.assertRaisesRegex(ValueError, "Attribute .* missing"): - with del_attribute_ctx(self.module, "z", enforce_hasattr=True): + with delattr_ctx(self.module, "z", enforce_hasattr=True): pass # pragma: no cover def test_requires_grad_ctx(self): @@ -127,7 +127,7 @@ def test_module_rollback_ctx(self): self.assertTrue(self.module.c.equal(c)) # Test that items in checkpoint get inserted into state_dict - with del_attribute_ctx(self.module, "a"): + with delattr_ctx(self.module, "a"): with self.assertRaisesRegex( # should fail when attempting to rollback RuntimeError, r'Unexpected key\(s\) in state_dict: "a"' ):