Skip to content

Commit 0425d46

Browse files
esantorella authored and facebook-github-bot committed
Don't allow unused **kwargs in input_constructors except for a defined set of exceptions (meta-pytorch#1872)
Summary: X-link: facebook/Ax#1772 Pull Request resolved: meta-pytorch#1872 [x] Remove unused arguments from input constructors and related functions. The idea is especially not to let unused keyword arguments disappear into `**kwargs` and be silently ignored. [x] Add arguments to some input constructors so they don't need any `**kwargs`. [x] Add a decorator that ensures that each input constructor can accept a certain set of keyword arguments, even if those are not used by the constructor, while still erroring on unexpected keyword arguments. [ ] Prevent arguments from having different defaults in the input constructors than in acquisition functions. Reviewed By: lena-kashtelyan Differential Revision: D46519588 fbshipit-source-id: 862fab3fca2460e04462b4b408b6bc3431baa0ff
1 parent fe122b0 commit 0425d46

File tree

8 files changed

+396
-300
lines changed

8 files changed

+396
-300
lines changed

botorch/acquisition/input_constructors.py

Lines changed: 176 additions & 181 deletions
Large diffs are not rendered by default.

botorch/acquisition/joint_entropy_search.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import warnings
2626
from math import log, pi
2727

28-
from typing import Any, Optional
28+
from typing import Optional
2929

3030
import torch
3131
from botorch import settings
@@ -78,7 +78,6 @@ def __init__(
7878
estimation_type: str = "LB",
7979
maximize: bool = True,
8080
num_samples: int = 64,
81-
**kwargs: Any,
8281
) -> None:
8382
r"""Joint entropy search acquisition function.
8483

botorch/acquisition/knowledge_gradient.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ def __init__(
7474
inner_sampler: Optional[MCSampler] = None,
7575
X_pending: Optional[Tensor] = None,
7676
current_value: Optional[Tensor] = None,
77-
**kwargs: Any,
7877
) -> None:
7978
r"""q-Knowledge Gradient (one-shot optimization).
8079
@@ -330,7 +329,6 @@ def __init__(
330329
expand: Callable[[Tensor], Tensor] = lambda X: X,
331330
valfunc_cls: Optional[Type[AcquisitionFunction]] = None,
332331
valfunc_argfac: Optional[Callable[[Model], Dict[str, Any]]] = None,
333-
**kwargs: Any,
334332
) -> None:
335333
r"""Multi-Fidelity q-Knowledge Gradient (one-shot optimization).
336334

botorch/acquisition/max_value_entropy_search.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,6 @@ def __init__(
326326
maximize: bool = True,
327327
X_pending: Optional[Tensor] = None,
328328
train_inputs: Optional[Tensor] = None,
329-
**kwargs: Any,
330329
) -> None:
331330
r"""Single-outcome max-value entropy search acquisition function.
332331
@@ -697,7 +696,6 @@ def __init__(
697696
cost_aware_utility: Optional[CostAwareUtility] = None,
698697
project: Callable[[Tensor], Tensor] = lambda X: X,
699698
expand: Callable[[Tensor], Tensor] = lambda X: X,
700-
**kwargs: Any,
701699
) -> None:
702700
r"""Single-outcome max-value entropy search acquisition function.
703701

botorch/acquisition/monte_carlo.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from abc import ABC, abstractmethod
2525
from copy import deepcopy
2626
from functools import partial
27-
from typing import Any, Callable, List, Optional, Protocol, Tuple, Union
27+
from typing import Callable, List, Optional, Protocol, Tuple, Union
2828

2929
import torch
3030
from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
@@ -351,7 +351,6 @@ def __init__(
351351
X_pending: Optional[Tensor] = None,
352352
constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
353353
eta: Union[Tensor, float] = 1e-3,
354-
**kwargs: Any,
355354
) -> None:
356355
r"""q-Expected Improvement.
357356
@@ -434,7 +433,7 @@ def __init__(
434433
cache_root: bool = True,
435434
constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
436435
eta: Union[Tensor, float] = 1e-3,
437-
**kwargs: Any,
436+
marginalize_dim: Optional[int] = None,
438437
) -> None:
439438
r"""q-Noisy Expected Improvement.
440439
@@ -469,6 +468,7 @@ def __init__(
469468
eta: Temperature parameter(s) governing the smoothness of the sigmoid
470469
approximation to the constraint indicators. For more details, on this
471470
parameter, see the docs of `compute_smoothed_feasibility_indicator`.
471+
marginalize_dim: The dimension to marginalize over.
472472
473473
TODO: similar to qNEHVI, when we are using sequential greedy candidate
474474
selection, we could incorporate pending points X_baseline and compute
@@ -491,7 +491,7 @@ def __init__(
491491
X=X_baseline,
492492
objective=objective,
493493
posterior_transform=posterior_transform,
494-
marginalize_dim=kwargs.get("marginalize_dim"),
494+
marginalize_dim=marginalize_dim,
495495
)
496496
self.register_buffer("X_baseline", X_baseline)
497497
# registering buffers for _get_samples_and_objectives in the next `if` block

botorch/acquisition/multi_objective/monte_carlo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from abc import ABC, abstractmethod
2828
from copy import deepcopy
2929
from itertools import combinations
30-
from typing import Any, Callable, List, Optional, Union
30+
from typing import Callable, List, Optional, Union
3131

3232
import torch
3333
from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
@@ -373,7 +373,7 @@ def __init__(
373373
max_iep: int = 0,
374374
incremental_nehvi: bool = True,
375375
cache_root: bool = True,
376-
**kwargs: Any,
376+
marginalize_dim: Optional[int] = None,
377377
) -> None:
378378
r"""q-Noisy Expected Hypervolume Improvement supporting m>=2 outcomes.
379379
@@ -466,7 +466,7 @@ def __init__(
466466
objective=objective,
467467
constraints=constraints,
468468
ref_point=ref_point,
469-
marginalize_dim=kwargs.get("marginalize_dim"),
469+
marginalize_dim=marginalize_dim,
470470
)
471471
self.register_buffer("ref_point", ref_point)
472472
self.alpha = alpha

0 commit comments

Comments (0)