Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions botorch/optim/closures/model_closures.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from typing import Any, Callable, Dict, Optional, Sequence, Tuple

from botorch.optim.closures.core import ForwardBackwardClosure
from botorch.optim.utils import TNone
from botorch.utils.dispatcher import Dispatcher, type_bypassing_encoder
from botorch.utils.types import NoneType
from gpytorch.mlls import (
ExactMarginalLogLikelihood,
MarginalLogLikelihood,
Expand Down Expand Up @@ -151,9 +151,9 @@ def closure(**kwargs: Any) -> Tensor:
return closure


@GetLossClosure.register(MarginalLogLikelihood, object, object, TNone)
@GetLossClosure.register(MarginalLogLikelihood, object, object, NoneType)
def _get_loss_closure_fallback_internal(
mll: MarginalLogLikelihood, _: object, __: object, ___: TNone, **ignore: Any
mll: MarginalLogLikelihood, _: object, __: object, ___: NoneType, **ignore: Any
) -> Callable[[], Tensor]:
r"""Fallback loss closure with internally managed data."""

Expand All @@ -165,9 +165,9 @@ def closure(**kwargs: Any) -> Tensor:
return closure


@GetLossClosure.register(ExactMarginalLogLikelihood, object, object, TNone)
@GetLossClosure.register(ExactMarginalLogLikelihood, object, object, NoneType)
def _get_loss_closure_exact_internal(
mll: ExactMarginalLogLikelihood, _: object, __: object, ___: TNone, **ignore: Any
mll: ExactMarginalLogLikelihood, _: object, __: object, ___: NoneType, **ignore: Any
) -> Callable[[], Tensor]:
r"""ExactMarginalLogLikelihood loss closure with internally managed data."""

Expand All @@ -181,9 +181,9 @@ def closure(**kwargs: Any) -> Tensor:
return closure


@GetLossClosure.register(SumMarginalLogLikelihood, object, object, TNone)
@GetLossClosure.register(SumMarginalLogLikelihood, object, object, NoneType)
def _get_loss_closure_sum_internal(
mll: SumMarginalLogLikelihood, _: object, __: object, ___: TNone, **ignore: Any
mll: SumMarginalLogLikelihood, _: object, __: object, ___: NoneType, **ignore: Any
) -> Callable[[], Tensor]:
r"""SumMarginalLogLikelihood loss closure with internally managed data."""

Expand Down
2 changes: 1 addition & 1 deletion botorch/optim/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@
from botorch.optim.utils import (
_filter_kwargs,
_get_extra_mll_args,
DEFAULT,
get_name_filter,
get_parameters_and_bounds,
TorchAttr,
)
from botorch.optim.utils.model_utils import get_parameters
from botorch.utils.types import DEFAULT
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from gpytorch.settings import fast_computations
from numpy import ndarray
Expand Down
4 changes: 0 additions & 4 deletions botorch/optim/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
_filter_kwargs,
_handle_numerical_errors,
_warning_handler_template,
DEFAULT,
TNone,
)
from botorch.optim.utils.model_utils import (
_get_extra_mll_args,
Expand All @@ -40,7 +38,6 @@
"_warning_handler_template",
"as_ndarray",
"columnwise_clamp",
"DEFAULT",
"fix_features",
"get_name_filter",
"get_bounds_as_ndarray",
Expand All @@ -53,5 +50,4 @@
"sample_all_priors",
"set_tensors_from_ndarray_1d",
"TorchAttr",
"TNone",
]
9 changes: 0 additions & 9 deletions botorch/optim/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,6 @@
import numpy as np
from linear_operator.utils.errors import NanError, NotPSDError

TNone = type(None)


class _TDefault:
pass


DEFAULT = _TDefault()


def _filter_kwargs(function: Callable, **kwargs: Any) -> Any:
r"""Filter out kwargs that are not applicable for a given function.
Expand Down
6 changes: 4 additions & 2 deletions botorch/optim/utils/numpy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import numpy as np
import torch
from botorch.optim.utils.common import TNone
from botorch.utils.types import NoneType
from numpy import ndarray
from torch import Tensor

Expand Down Expand Up @@ -137,7 +137,9 @@ def set_tensors_from_ndarray_1d(

def get_bounds_as_ndarray(
parameters: Dict[str, Tensor],
bounds: Dict[str, Tuple[Union[float, Tensor, TNone], Union[float, Tensor, TNone]]],
bounds: Dict[
str, Tuple[Union[float, Tensor, NoneType], Union[float, Tensor, NoneType]]
],
) -> Optional[np.ndarray]:
r"""Helper method for converting bounds into an ndarray.

Expand Down
39 changes: 39 additions & 0 deletions botorch/sampling/pathwise/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from botorch.sampling.pathwise.features import (
gen_kernel_features,
KernelEvaluationMap,
KernelFeatureMap,
)
from botorch.sampling.pathwise.paths import (
GeneralizedLinearPath,
PathDict,
PathList,
SamplePath,
)
from botorch.sampling.pathwise.posterior_samplers import (
draw_matheron_paths,
MatheronPath,
)
from botorch.sampling.pathwise.prior_samplers import draw_kernel_feature_paths
from botorch.sampling.pathwise.update_strategies import gaussian_update


__all__ = [
"draw_matheron_paths",
"draw_kernel_feature_paths",
"gen_kernel_features",
"gaussian_update",
"GeneralizedLinearPath",
"KernelEvaluationMap",
"KernelFeatureMap",
"MatheronPath",
"SamplePath",
"PathDict",
"PathList",
]
20 changes: 20 additions & 0 deletions botorch/sampling/pathwise/features/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from botorch.sampling.pathwise.features.generators import gen_kernel_features
from botorch.sampling.pathwise.features.maps import (
FeatureMap,
KernelEvaluationMap,
KernelFeatureMap,
)

__all__ = [
"FeatureMap",
"gen_kernel_features",
"KernelEvaluationMap",
"KernelFeatureMap",
]
193 changes: 193 additions & 0 deletions botorch/sampling/pathwise/features/generators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
.. [rahimi2007random]
A. Rahimi and B. Recht. Random features for large-scale kernel machines.
Advances in Neural Information Processing Systems 20 (2007).

.. [sutherland2015error]
D. J. Sutherland and J. Schneider. On the error of random Fourier features.
arXiv preprint arXiv:1506.02785 (2015).
"""

from __future__ import annotations

from typing import Any, Callable

import torch
from botorch.exceptions.errors import UnsupportedError
from botorch.sampling.pathwise.features.maps import KernelFeatureMap
from botorch.sampling.pathwise.utils import (
ChainedTransform,
FeatureSelector,
InverseLengthscaleTransform,
OutputscaleTransform,
SineCosineTransform,
)
from botorch.utils.dispatcher import Dispatcher
from botorch.utils.sampling import draw_sobol_normal_samples
from gpytorch import kernels
from gpytorch.kernels.kernel import Kernel
from torch import Size, Tensor
from torch.distributions import Gamma

TKernelFeatureMapGenerator = Callable[[Kernel, int, int], KernelFeatureMap]
GenKernelFeatures = Dispatcher("gen_kernel_features")


def gen_kernel_features(
kernel: kernels.Kernel,
num_inputs: int,
num_outputs: int,
**kwargs: Any,
) -> KernelFeatureMap:
r"""Generates a feature map :math:`\phi: \mathcal{X} \to \mathbb{R}^{n}` such that
:math:`k(x, x') ≈ \phi(x)^{T} \phi(x')`. For stationary kernels :math:`k`, defaults
to the method of random Fourier features. For more details, see [rahimi2007random]_
and [sutherland2015error]_.

Args:
kernel: The kernel :math:`k` to be represented via a finite-dim basis.
num_inputs: The number of input features.
num_outputs: The number of kernel features.
"""
return GenKernelFeatures(
kernel,
num_inputs=num_inputs,
num_outputs=num_outputs,
**kwargs,
)


def _gen_fourier_features(
kernel: kernels.Kernel,
weight_generator: Callable[[Size], Tensor],
num_inputs: int,
num_outputs: int,
) -> KernelFeatureMap:
r"""Generate a feature map :math:`\phi: \mathcal{X} \to \mathbb{R}^{2l}` that
approximates a stationary kernel so that :math:`k(x, x') ≈ \phi(x)^\top \phi(x')`.

Following [sutherland2015error]_, we represent complex exponentials by pairs of
basis functions :math:`\phi_{i}(x) = \sin(x^\top w_{i})` and
:math:`\phi_{i + l} = \cos(x^\top w_{i}).

Args:
kernel: A stationary kernel :math:`k(x, x') = k(x - x')`.
weight_generator: A callable used to generate weight vectors :math:`w`.
num_inputs: The number of input features.
num_outputs: The number of Fourier features.
"""
if num_outputs % 2:
raise UnsupportedError(
f"Expected an even number of output features, but received {num_outputs=}."
)

input_transform = InverseLengthscaleTransform(kernel)
if kernel.active_dims is not None:
num_inputs = len(kernel.active_dims)
input_transform = ChainedTransform(
input_transform, FeatureSelector(indices=kernel.active_dims)
)

weight = weight_generator(
Size([kernel.batch_shape.numel() * num_outputs // 2, num_inputs])
).reshape(*kernel.batch_shape, num_outputs // 2, num_inputs)

output_transform = SineCosineTransform(
torch.tensor((2 / num_outputs) ** 0.5, device=kernel.device, dtype=kernel.dtype)
)
return KernelFeatureMap(
kernel=kernel,
weight=weight,
input_transform=input_transform,
output_transform=output_transform,
)


@GenKernelFeatures.register(kernels.RBFKernel)
def _gen_kernel_features_rbf(
kernel: kernels.RBFKernel,
*,
num_inputs: int,
num_outputs: int,
) -> KernelFeatureMap:
def _weight_generator(shape: Size) -> Tensor:
try:
n, d = shape
except ValueError:
raise UnsupportedError(
f"Expected `shape` to be 2-dimensional, but {len(shape)=}."
)

return draw_sobol_normal_samples(
n=n,
d=d,
device=kernel.lengthscale.device,
dtype=kernel.lengthscale.dtype,
)

return _gen_fourier_features(
kernel=kernel,
weight_generator=_weight_generator,
num_inputs=num_inputs,
num_outputs=num_outputs,
)


@GenKernelFeatures.register(kernels.MaternKernel)
def _gen_kernel_features_matern(
kernel: kernels.MaternKernel,
*,
num_inputs: int,
num_outputs: int,
) -> KernelFeatureMap:
def _weight_generator(shape: Size) -> Tensor:
try:
n, d = shape
except ValueError:
raise UnsupportedError(
f"Expected `shape` to be 2-dimensional, but {len(shape)=}."
)

dtype = kernel.lengthscale.dtype
device = kernel.lengthscale.device
nu = torch.tensor(kernel.nu, device=device, dtype=dtype)
normals = draw_sobol_normal_samples(n=n, d=d, device=device, dtype=dtype)
return Gamma(nu, nu).rsample((n, 1)).rsqrt() * normals

return _gen_fourier_features(
kernel=kernel,
weight_generator=_weight_generator,
num_inputs=num_inputs,
num_outputs=num_outputs,
)


@GenKernelFeatures.register(kernels.ScaleKernel)
def _gen_kernel_features_scale(
kernel: kernels.ScaleKernel,
*,
num_inputs: int,
num_outputs: int,
) -> KernelFeatureMap:
active_dims = kernel.active_dims
feature_map = gen_kernel_features(
kernel.base_kernel,
num_inputs=num_inputs if active_dims is None else len(active_dims),
num_outputs=num_outputs,
)

if active_dims is not None and active_dims is not kernel.base_kernel.active_dims:
feature_map.input_transform = ChainedTransform(
feature_map.input_transform, FeatureSelector(indices=active_dims)
)

feature_map.output_transform = ChainedTransform(
OutputscaleTransform(kernel), feature_map.output_transform
)
return feature_map
Loading