Skip to content

Commit 936a0a7

Browse files
Merge pull request #1 from j-wilson/export-D40662482
sampling.pathwise (meta-pytorch#1463)
2 parents 2437241 + 4758b62 commit 936a0a7

File tree

13 files changed

+1143
-0
lines changed

13 files changed

+1143
-0
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
from botorch.sampling.pathwise.basis import (
9+
BasisExpansion,
10+
fourier_feature_initializer,
11+
GeneralizedLinearBasis,
12+
KernelBasis,
13+
)
14+
from botorch.sampling.pathwise.matheron import draw_matheron_paths, MatheronPath
15+
from botorch.sampling.pathwise.paths import AffinePath, CompositePath, SamplePath
16+
from botorch.sampling.pathwise.prior_samplers import draw_bayes_linear_paths
17+
from botorch.sampling.pathwise.update_strategies import exact_update
18+
19+
20+
__all__ = [
21+
"BasisExpansion",
22+
"CompositePath",
23+
"exact_update",
24+
"fourier_feature_initializer",
25+
"draw_matheron_paths",
26+
"draw_bayes_linear_paths",
27+
"GeneralizedLinearBasis",
28+
"KernelBasis",
29+
"KernelFeatureInitializer",
30+
"AffinePath",
31+
"MatheronPath",
32+
"SamplePath",
33+
]
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
from botorch.sampling.pathwise.basis.expansions import (
9+
BasisExpansion,
10+
GeneralizedLinearBasis,
11+
KernelBasis,
12+
)
13+
from botorch.sampling.pathwise.basis.initializers import (
14+
fourier_feature_initializer,
15+
GeneralizedLinearInitialization,
16+
GeneralizedLinearInitializer,
17+
KernelFeatureInitializer,
18+
)
19+
20+
__all__ = [
21+
"BasisExpansion",
22+
"fourier_feature_initializer",
23+
"GeneralizedLinearBasis",
24+
"GeneralizedLinearInitialization",
25+
"GeneralizedLinearInitializer",
26+
"KernelBasis",
27+
"KernelFeatureInitializer",
28+
]
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from __future__ import annotations
8+
9+
from typing import Optional
10+
11+
import torch
12+
from botorch.sampling.pathwise.basis.initializers import GeneralizedLinearInitializer
13+
from botorch.sampling.pathwise.utils.common import TensorTransform
14+
from gpytorch.kernels import Kernel
15+
from torch import Size, Tensor
16+
from torch.nn import Module, Parameter
17+
18+
19+
class BasisExpansion(Module):
20+
output_shape: Size
21+
batch_shape: Optional[Size]
22+
23+
24+
class GeneralizedLinearBasis(BasisExpansion):
25+
r"""Generalized linear basis functions:
26+
`phi(x) = output_transform(input_transform(x)^T @ weight + bias)`."""
27+
28+
input_transform: Optional[TensorTransform]
29+
output_transform: Optional[TensorTransform]
30+
31+
def __init__(
32+
self,
33+
initializer: GeneralizedLinearInitializer,
34+
output_shape: Size,
35+
batch_shape: Optional[Size] = None,
36+
) -> None:
37+
super().__init__()
38+
self.batch_shape = Size() if batch_shape is None else batch_shape
39+
self.output_shape = output_shape
40+
self.initializer = initializer
41+
self._initialized = Parameter(
42+
torch.zeros([], dtype=torch.bool), requires_grad=False
43+
)
44+
45+
def forward(self, x: Tensor) -> Tensor:
46+
if not self._initialized:
47+
self.initialize(input_shape=x.shape[-1:])
48+
49+
x = x if self.input_transform is None else self.input_transform(x)
50+
z = x @ self.weight.transpose(-2, -1)
51+
if self.bias is not None:
52+
z = z + self.bias
53+
54+
return z if self.output_transform is None else self.output_transform(z)
55+
56+
def initialize(self, input_shape: Size) -> None:
57+
weight, bias, input_transform, output_transform = self.initializer(
58+
input_shape=input_shape,
59+
output_shape=self.batch_shape + self.output_shape,
60+
)
61+
self.bias = bias
62+
self.weight = weight
63+
self.input_transform = input_transform
64+
self.output_transform = output_transform
65+
self._initialized[...] = True
66+
67+
68+
class KernelBasis(BasisExpansion):
69+
def __init__(self, kernel: Kernel, centers: Tensor) -> None:
70+
r"""Canonical basis functions $\phi_{i}(x) = k(x, z_{i})."""
71+
try:
72+
torch.broadcast_shapes(centers.shape[:-2], kernel.batch_shape)
73+
except RuntimeError as e:
74+
raise RuntimeError(
75+
f"Shape mismatch: `centers` has shape {centers.shape}, "
76+
f"but kernel.batch_shape={kernel.batch_shape}."
77+
) from e
78+
79+
super().__init__()
80+
self.kernel = kernel
81+
self.centers = centers
82+
83+
def forward(self, x: Tensor) -> Tensor:
84+
return self.kernel(x, self.centers)
85+
86+
@property
87+
def batch_shape(self) -> Size:
88+
return self.kernel.batch_shape
89+
90+
@property
91+
def output_shape(self) -> Size:
92+
return self.centers.shape[-2:-1]
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
r"""
8+
.. [rahimi2007random]
9+
A. Rahimi and B. Recht. Random features for large-scale kernel machines.
10+
Advances in neural information processing systems 20 (2007).
11+
12+
.. [sutherland2015error]
13+
D. J. Sutherland and J. Schneider. On the error of random Fourier features.
14+
arXiv preprint arXiv:1506.02785 (2015).
15+
"""
16+
17+
from __future__ import annotations
18+
19+
from typing import Any, Callable, NamedTuple, Optional, Protocol, runtime_checkable
20+
21+
import torch
22+
from botorch.sampling.pathwise.utils.common import TensorTransform
23+
from botorch.utils.dispatcher import Dispatcher
24+
from botorch.utils.sampling import draw_sobol_normal_samples
25+
from gpytorch import kernels
26+
from torch import Size, Tensor
27+
from torch.distributions import Gamma
28+
29+
FourierFeatureInitializer = Dispatcher("fourier_feature_initializer")
30+
NoneType = type(None)
31+
32+
33+
class GeneralizedLinearInitialization(NamedTuple):
34+
r"""Initialization for generalized linear basis functions `phi(x)`. Formally:
35+
`phi(x) = output_transform(input_transform(x)^T @ weight + bias)`.
36+
"""
37+
weight: Tensor
38+
bias: Optional[Tensor]
39+
input_transform: Optional[TensorTransform]
40+
output_transform: Optional[TensorTransform]
41+
42+
43+
@runtime_checkable
44+
class GeneralizedLinearInitializer(Protocol):
45+
def __call__(
46+
self, input_shape: Size, output_shape: Size, **kwargs: Any
47+
) -> GeneralizedLinearInitialization:
48+
pass
49+
50+
51+
@runtime_checkable
52+
class KernelFeatureInitializer(Protocol):
53+
def __call__(
54+
self,
55+
kernel: kernels.Kernel,
56+
input_shape: Size,
57+
output_shape: Size,
58+
**kwargs: Any,
59+
) -> GeneralizedLinearInitialization:
60+
pass
61+
62+
63+
def fourier_feature_initializer(
64+
kernel: kernels.Kernel,
65+
input_shape: Size,
66+
output_shape: Size,
67+
**kwargs: Any,
68+
) -> GeneralizedLinearInitialization:
69+
return FourierFeatureInitializer(
70+
kernel, input_shape=input_shape, output_shape=output_shape, **kwargs
71+
)
72+
73+
74+
def _fourier_initializer_stationary_sincos(
75+
kernel: kernels.Kernel,
76+
weight_initializer: Callable[[Size], Tensor],
77+
input_shape: Size,
78+
output_shape: Size,
79+
) -> GeneralizedLinearInitialization:
80+
r"""Returns a (2 * l)-dimensional feature map `phi: X -> R^{2l}` whose inner product
81+
phi(x)^T phi(x') approximates the evaluation of a stationary kernel
82+
`k(x, x') = k(x - x')`. For details, see [rahimi2007random]_.
83+
84+
As argued for in [sutherland2015error]_, we use Euler's formula to represent
85+
complex exponential basis functions as pairs of trigonometric bases:
86+
87+
`phi_{i}(x) = sin(x^T w_{i})` and `phi_{i + l} = cos(x^T w_{i})`
88+
"""
89+
assert (output_shape[-1] % 2) == 0
90+
shape = output_shape[:-1] + (output_shape[-1] // 2,)
91+
scale = (2 / output_shape[-1]) ** 0.5
92+
weight = weight_initializer(Size([shape.numel(), input_shape.numel()])).reshape(
93+
*shape, *input_shape
94+
)
95+
96+
def input_transform(raw_inputs: Tensor) -> Tensor:
97+
return kernel.lengthscale.reciprocal() * raw_inputs
98+
99+
def output_transform(raw_features: Tensor) -> Tensor:
100+
return scale * torch.concat([raw_features.sin(), raw_features.cos()], dim=-1)
101+
102+
return GeneralizedLinearInitialization(
103+
weight, None, input_transform, output_transform
104+
)
105+
106+
107+
@FourierFeatureInitializer.register(kernels.RBFKernel)
108+
def _fourier_initializer_rbf(
109+
kernel: kernels.RBFKernel,
110+
*,
111+
input_shape: Size,
112+
output_shape: Size,
113+
) -> GeneralizedLinearInitialization:
114+
def _weight_initializer(shape: Size) -> Tensor:
115+
if len(shape) != 2:
116+
raise NotImplementedError
117+
118+
return draw_sobol_normal_samples(
119+
n=shape[0],
120+
d=shape[1],
121+
device=kernel.lengthscale.device,
122+
dtype=kernel.lengthscale.dtype,
123+
)
124+
125+
return _fourier_initializer_stationary_sincos(
126+
kernel=kernel,
127+
weight_initializer=_weight_initializer,
128+
input_shape=input_shape,
129+
output_shape=output_shape,
130+
)
131+
132+
133+
@FourierFeatureInitializer.register(kernels.MaternKernel)
134+
def _fourier_initializer_matern(
135+
kernel: kernels.MaternKernel,
136+
*,
137+
input_shape: Size,
138+
output_shape: Size,
139+
) -> GeneralizedLinearInitialization:
140+
def _weight_initializer(shape: Size) -> Tensor:
141+
try:
142+
n, d = shape
143+
except ValueError:
144+
raise NotImplementedError(
145+
f"Expected `shape` to be size 2, but is size {len(shape)}."
146+
)
147+
148+
dtype = kernel.lengthscale.dtype
149+
device = kernel.lengthscale.device
150+
nu = torch.tensor(kernel.nu, device=device, dtype=dtype)
151+
normals = draw_sobol_normal_samples(n=n, d=d, device=device, dtype=dtype)
152+
return Gamma(nu, nu).rsample((n, 1)).rsqrt() * normals
153+
154+
return _fourier_initializer_stationary_sincos(
155+
kernel=kernel,
156+
weight_initializer=_weight_initializer,
157+
input_shape=input_shape,
158+
output_shape=output_shape,
159+
)
160+
161+
162+
@FourierFeatureInitializer.register(kernels.ScaleKernel)
163+
def _fourier_initializer_scale(
164+
kernel: kernels.ScaleKernel,
165+
*,
166+
input_shape: Size,
167+
output_shape: Size,
168+
) -> GeneralizedLinearInitialization:
169+
170+
weight, bias, input_transform, output_transform = fourier_feature_initializer(
171+
kernel.base_kernel,
172+
input_shape=input_shape,
173+
output_shape=output_shape,
174+
)
175+
176+
def scaled_output_transform(raw_features: Tensor) -> Tensor:
177+
features = (
178+
raw_features if output_transform is None else output_transform(raw_features)
179+
)
180+
outputscale = kernel.outputscale
181+
while outputscale.ndim < features.ndim:
182+
outputscale = outputscale.unsqueeze(-1)
183+
184+
return outputscale.sqrt() * features
185+
186+
return GeneralizedLinearInitialization(
187+
weight, bias, input_transform, scaled_output_transform
188+
)

0 commit comments

Comments
 (0)