Skip to content

Commit a10d9f2

Browse files
authored
Merge e147346 into 63dd0cd
2 parents 63dd0cd + e147346 commit a10d9f2

File tree

10 files changed

+219
-50
lines changed

10 files changed

+219
-50
lines changed

botorch/acquisition/fixed_feature.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616

1717
import torch
1818
from botorch.acquisition.acquisition import AcquisitionFunction
19+
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
1920
from torch import Tensor
20-
from torch.nn import Module
2121

2222

23-
class FixedFeatureAcquisitionFunction(AcquisitionFunction):
23+
class FixedFeatureAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
2424
"""A wrapper around AquisitionFunctions to fix a subset of features.
2525
2626
Example:
@@ -56,8 +56,7 @@ def __init__(
5656
combination of `Tensor`s and numbers which can be broadcasted
5757
to form a tensor with trailing dimension size of `d_f`.
5858
"""
59-
Module.__init__(self)
60-
self.acq_func = acq_function
59+
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function)
6160
dtype = torch.float
6261
device = torch.device("cpu")
6362
self.d = d
@@ -126,24 +125,13 @@ def forward(self, X: Tensor):
126125
X_full = self._construct_X_full(X)
127126
return self.acq_func(X_full)
128127

129-
@property
130-
def X_pending(self):
131-
r"""Return the `X_pending` of the base acquisition function."""
132-
try:
133-
return self.acq_func.X_pending
134-
except (ValueError, AttributeError):
135-
raise ValueError(
136-
f"Base acquisition function {type(self.acq_func).__name__} "
137-
"does not have an `X_pending` attribute."
138-
)
139-
140-
@X_pending.setter
141-
def X_pending(self, X_pending: Optional[Tensor]):
128+
def set_X_pending(self, X_pending: Optional[Tensor]):
142129
r"""Sets the `X_pending` of the base acquisition function."""
143130
if X_pending is not None:
144-
self.acq_func.X_pending = self._construct_X_full(X_pending)
131+
full_X_pending = self._construct_X_full(X_pending)
145132
else:
146-
self.acq_func.X_pending = X_pending
133+
full_X_pending = None
134+
self.acq_func.set_X_pending(full_X_pending)
147135

148136
def _construct_X_full(self, X: Tensor) -> Tensor:
149137
r"""Constructs the full input for the base acquisition function.

botorch/acquisition/penalized.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@
1515

1616
import torch
1717
from botorch.acquisition.acquisition import AcquisitionFunction
18-
from botorch.acquisition.analytic import AnalyticAcquisitionFunction
1918
from botorch.acquisition.objective import GenericMCObjective
20-
from botorch.exceptions import UnsupportedError
19+
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
2120
from torch import Tensor
2221

2322

@@ -139,7 +138,7 @@ def forward(self, X: Tensor) -> Tensor:
139138
return regularization_term
140139

141140

142-
class PenalizedAcquisitionFunction(AcquisitionFunction):
141+
class PenalizedAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
143142
r"""Single-outcome acquisition function regularized by the given penalty.
144143
145144
The usage is similar to:
@@ -161,29 +160,16 @@ def __init__(
161160
penalty_func: The regularization function.
162161
regularization_parameter: Regularization parameter used in optimization.
163162
"""
164-
super().__init__(model=raw_acqf.model)
165-
self.raw_acqf = raw_acqf
163+
AcquisitionFunction.__init__(self, model=raw_acqf.model)
164+
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=raw_acqf)
166165
self.penalty_func = penalty_func
167166
self.regularization_parameter = regularization_parameter
168167

169168
def forward(self, X: Tensor) -> Tensor:
170-
raw_value = self.raw_acqf(X=X)
169+
raw_value = self.acq_func(X=X)
171170
penalty_term = self.penalty_func(X)
172171
return raw_value - self.regularization_parameter * penalty_term
173172

174-
@property
175-
def X_pending(self) -> Optional[Tensor]:
176-
return self.raw_acqf.X_pending
177-
178-
def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
179-
if not isinstance(self.raw_acqf, AnalyticAcquisitionFunction):
180-
self.raw_acqf.set_X_pending(X_pending=X_pending)
181-
else:
182-
raise UnsupportedError(
183-
"The raw acquisition function is Analytic and does not account "
184-
"for X_pending yet."
185-
)
186-
187173

188174
def group_lasso_regularizer(X: Tensor, groups: List[List[int]]) -> Tensor:
189175
r"""Computes the group lasso regularization function for the given point.

botorch/acquisition/proximal.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import torch
1717
from botorch.acquisition import AcquisitionFunction
18+
19+
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
1820
from botorch.exceptions.errors import UnsupportedError
1921
from botorch.models import ModelListGP
2022
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
@@ -25,7 +27,7 @@
2527
from torch.nn import Module
2628

2729

28-
class ProximalAcquisitionFunction(AcquisitionFunction):
30+
class ProximalAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
2931
"""A wrapper around AcquisitionFunctions to add proximal weighting of the
3032
acquisition function. The acquisition function is
3133
weighted via a squared exponential centered at the last training point,
@@ -70,17 +72,14 @@ def __init__(
7072
beta: If not None, apply a softplus transform to the base acquisition
7173
function, allows negative base acquisition function values.
7274
"""
73-
Module.__init__(self)
74-
75-
self.acq_func = acq_function
75+
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function)
7676
model = self.acq_func.model
7777

7878
if hasattr(acq_function, "X_pending"):
7979
if acq_function.X_pending is not None:
8080
raise UnsupportedError(
8181
"Proximal acquisition function requires `X_pending` to be None."
8282
)
83-
self.X_pending = acq_function.X_pending
8483

8584
self.register_buffer("proximal_weights", proximal_weights)
8685
self.register_buffer(
@@ -91,6 +90,12 @@ def __init__(
9190

9291
_validate_model(model, proximal_weights)
9392

93+
def set_X_pending(self, X_pending: Optional[Tensor]) -> None:
94+
r"""Sets the `X_pending` of the base acquisition function."""
95+
raise UnsupportedError(
96+
"Proximal acquisition function does not support `X_pending`."
97+
)
98+
9499
@t_batch_mode_transform(expected_q=1, assert_output_shape=False)
95100
def forward(self, X: Tensor) -> Tensor:
96101
r"""Evaluate base acquisition function with proximal weighting.

botorch/acquisition/utils.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from __future__ import annotations
1212

1313
import math
14-
from typing import Callable, Dict, List, Optional, Union
14+
from typing import Any, Callable, Dict, List, Optional, Union
1515

1616
import torch
1717
from botorch.acquisition import analytic, monte_carlo, multi_objective # noqa F401
@@ -22,6 +22,7 @@
2222
MCAcquisitionObjective,
2323
PosteriorTransform,
2424
)
25+
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
2526
from botorch.exceptions.errors import UnsupportedError
2627
from botorch.models.fully_bayesian import MCMC_DIM
2728
from botorch.models.model import Model
@@ -253,6 +254,18 @@ def objective(Y: Tensor, X: Optional[Tensor] = None):
253254
return -(lb.clamp_max(0.0))
254255

255256

257+
def isinstance_af(
258+
__obj: object,
259+
__class_or_tuple: Union[type, tuple[Union[type, tuple[Any, ...]], ...]],
260+
) -> bool:
261+
r"""A variant of isinstance first checks for the acq_func attribute on wrapped acquisition functions."""
262+
if isinstance(__obj, AbstractAcquisitionFunctionWrapper):
263+
isinstance_base_af = isinstance(__obj.acq_func, __class_or_tuple)
264+
else:
265+
isinstance_base_af = False
266+
return isinstance_base_af or isinstance(__obj, __class_or_tuple)
267+
268+
256269
def is_nonnegative(acq_function: AcquisitionFunction) -> bool:
257270
r"""Determine whether a given acquisition function is non-negative.
258271
@@ -267,7 +280,7 @@ def is_nonnegative(acq_function: AcquisitionFunction) -> bool:
267280
>>> qEI = qExpectedImprovement(model, best_f=0.1)
268281
>>> is_nonnegative(qEI) # returns True
269282
"""
270-
return isinstance(
283+
return isinstance_af(
271284
acq_function,
272285
(
273286
analytic.ExpectedImprovement,

botorch/acquisition/wrapper.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
r"""
8+
A wrapper classes around AcquisitionFunctions to modify inputs and outputs.
9+
"""
10+
11+
from __future__ import annotations
12+
13+
from abc import ABC, abstractmethod
14+
from typing import Optional
15+
16+
from botorch.acquisition.acquisition import AcquisitionFunction
17+
from torch import Tensor
18+
from torch.nn import Module
19+
20+
21+
class AbstractAcquisitionFunctionWrapper(AcquisitionFunction, ABC):
22+
r"""Abstract acquisition wrapper."""
23+
24+
def __init__(self, acq_function: AcquisitionFunction) -> None:
25+
Module.__init__(self)
26+
self.acq_func = acq_function
27+
28+
@property
29+
def X_pending(self) -> Optional[Tensor]:
30+
r"""Return the `X_pending` of the base acquisition function."""
31+
try:
32+
return self.acq_func.X_pending
33+
except (ValueError, AttributeError):
34+
raise ValueError(
35+
f"Base acquisition function {type(self.acq_func).__name__} "
36+
"does not have an `X_pending` attribute."
37+
)
38+
39+
def set_X_pending(self, X_pending: Optional[Tensor]) -> None:
40+
r"""Sets the `X_pending` of the base acquisition function."""
41+
self.acq_func.set_X_pending(X_pending)
42+
43+
@abstractmethod
44+
def forward(self, X: Tensor) -> Tensor:
45+
r"""Evaluate the wrapped acquisition function on the candidate set X.
46+
47+
Args:
48+
X: A `(b) x q x d`-dim Tensor of `(b)` t-batches with `q` `d`-dim
49+
design points each.
50+
51+
Returns:
52+
A `(b)`-dim Tensor of acquisition function values at the given
53+
design points `X`.
54+
"""
55+
pass # pragma: no cover

sphinx/source/acquisition.rst

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ Analytic Acquisition Function API
2121
.. autoclass:: AnalyticAcquisitionFunction
2222
:members:
2323

24+
Acquisition Function Wrapper API
25+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
26+
.. automodule:: botorch.acquisition.wrapper
27+
:members:
28+
2429
Cached Cholesky Acquisition Function API
2530
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2631
.. automodule:: botorch.acquisition.cached_cholesky
@@ -65,7 +70,7 @@ Multi-Objective Analytic Acquisition Functions
6570
.. automodule:: botorch.acquisition.multi_objective.analytic
6671
:members:
6772
:exclude-members: MultiObjectiveAnalyticAcquisitionFunction
68-
73+
6974
Multi-Objective Joint Entropy Search Acquisition Functions
7075
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
7176
.. automodule:: botorch.acquisition.multi_objective.joint_entropy_search
@@ -86,7 +91,7 @@ Multi-Objective Multi-Fidelity Acquisition Functions
8691
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
8792
.. automodule:: botorch.acquisition.multi_objective.multi_fidelity
8893
:members:
89-
94+
9095
Multi-Objective Predictive Entropy Search Acquisition Functions
9196
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
9297
.. automodule:: botorch.acquisition.multi_objective.predictive_entropy_search

test/acquisition/test_fixed_feature.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def test_fixed_features(self):
8787
qEI_ff.set_X_pending(X_pending[..., :-1])
8888
self.assertAllClose(qEI.X_pending, X_pending)
8989
# test setting to None
90-
qEI_ff.X_pending = None
90+
qEI_ff.set_X_pending(None)
9191
self.assertIsNone(qEI_ff.X_pending)
9292

9393
# test gradient

test/acquisition/test_proximal.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,15 @@ def test_proximal(self):
209209

210210
# test for x_pending points
211211
pending_acq = DummyAcquisitionFunction(model)
212-
pending_acq.set_X_pending(torch.rand(3, 3, device=self.device, dtype=dtype))
212+
X_pending = torch.rand(3, 3, device=self.device, dtype=dtype)
213+
pending_acq.set_X_pending(X_pending)
213214
with self.assertRaises(UnsupportedError):
214215
ProximalAcquisitionFunction(pending_acq, proximal_weights)
216+
# test setting pending points
217+
pending_acq.set_X_pending(None)
218+
af = ProximalAcquisitionFunction(pending_acq, proximal_weights)
219+
with self.assertRaises(UnsupportedError):
220+
af.set_X_pending(X_pending)
215221

216222
# test model with multi-batch training inputs
217223
train_X = torch.rand(5, 2, 3, device=self.device, dtype=dtype)

test/acquisition/test_utils.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from unittest import mock
99

1010
import torch
11-
from botorch.acquisition import monte_carlo
11+
from botorch.acquisition import analytic, monte_carlo, multi_objective
12+
from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
1213
from botorch.acquisition.multi_objective import (
1314
MCMultiOutputObjective,
1415
monte_carlo as moo_monte_carlo,
@@ -18,10 +19,13 @@
1819
MCAcquisitionObjective,
1920
ScalarizedPosteriorTransform,
2021
)
22+
from botorch.acquisition.proximal import ProximalAcquisitionFunction
2123
from botorch.acquisition.utils import (
2224
expand_trace_observations,
2325
get_acquisition_function,
2426
get_infeasible_cost,
27+
is_nonnegative,
28+
isinstance_af,
2529
project_to_sample_points,
2630
project_to_target_fidelity,
2731
prune_inferior_points,
@@ -606,6 +610,61 @@ def test_get_infeasible_cost(self):
606610
self.assertAllClose(M4, torch.tensor([1.0], **tkwargs))
607611

608612

613+
class TestIsNonnegative(BotorchTestCase):
614+
def test_is_nonnegative(self):
615+
nonneg_afs = (
616+
analytic.ExpectedImprovement,
617+
analytic.ConstrainedExpectedImprovement,
618+
analytic.ProbabilityOfImprovement,
619+
analytic.NoisyExpectedImprovement,
620+
monte_carlo.qExpectedImprovement,
621+
monte_carlo.qNoisyExpectedImprovement,
622+
monte_carlo.qProbabilityOfImprovement,
623+
multi_objective.analytic.ExpectedHypervolumeImprovement,
624+
multi_objective.monte_carlo.qExpectedHypervolumeImprovement,
625+
multi_objective.monte_carlo.qNoisyExpectedHypervolumeImprovement,
626+
)
627+
mm = MockModel(
628+
MockPosterior(
629+
mean=torch.rand(1, 1, device=self.device),
630+
variance=torch.ones(1, 1, device=self.device),
631+
)
632+
)
633+
acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0)
634+
with mock.patch(
635+
"botorch.acquisition.utils.isinstance_af", return_value=True
636+
) as mock_isinstance_af:
637+
self.assertTrue(is_nonnegative(acq_function=acq_func))
638+
mock_isinstance_af.assert_called_once()
639+
cargs, _ = mock_isinstance_af.call_args
640+
self.assertIs(cargs[0], acq_func)
641+
self.assertEqual(cargs[1], nonneg_afs)
642+
acq_func = analytic.UpperConfidenceBound(model=mm, beta=2.0)
643+
self.assertFalse(is_nonnegative(acq_function=acq_func))
644+
645+
646+
class TestIsinstanceAf(BotorchTestCase):
647+
def test_isinstance_af(self):
648+
mm = MockModel(
649+
MockPosterior(
650+
mean=torch.rand(1, 1, device=self.device),
651+
variance=torch.ones(1, 1, device=self.device),
652+
)
653+
)
654+
acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0)
655+
self.assertTrue(isinstance_af(acq_func, analytic.ExpectedImprovement))
656+
self.assertFalse(isinstance_af(acq_func, analytic.UpperConfidenceBound))
657+
wrapped_af = FixedFeatureAcquisitionFunction(
658+
acq_function=acq_func, d=2, columns=[1], values=[0.0]
659+
)
660+
# test base af class
661+
self.assertTrue(isinstance_af(wrapped_af, analytic.ExpectedImprovement))
662+
self.assertFalse(isinstance_af(wrapped_af, analytic.UpperConfidenceBound))
663+
# test wrapper class
664+
self.assertTrue(isinstance_af(wrapped_af, FixedFeatureAcquisitionFunction))
665+
self.assertFalse(isinstance_af(wrapped_af, ProximalAcquisitionFunction))
666+
667+
609668
class TestPruneInferiorPoints(BotorchTestCase):
610669
def test_prune_inferior_points(self):
611670
for dtype in (torch.float, torch.double):

0 commit comments

Comments
 (0)