Skip to content

Commit f61aaae

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
SingleTaskGP with empty inputs [Do not land] (#1229)
Summary: Pull Request resolved: #1229 See #1217 Differential Revision: D36490304 fbshipit-source-id: 3495ee4b3de75c2764d6cc0b4ad67e2293e4e341
1 parent 3805046 commit f61aaae

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

botorch/fit.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,13 @@ def fit_gpytorch_model(
115115
return fit_gpytorch_model(
116116
mll=mll, optimizer=optimizer, sequential=False, max_retries=max_retries
117117
)
118-
# retry with random samples from the priors upon failure
118+
# Skip training if there's no training data.
119+
if mll.model.train_targets.numel() == 0:
120+
logging.log(
121+
logging.DEBUG, "Skipping model training due to empty training data."
122+
)
123+
return mll
124+
# Retry with random samples from the priors upon failure.
119125
mll.train()
120126
original_state_dict = deepcopy(mll.model.state_dict())
121127
retry = 0

test/models/test_gp_regression.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737

3838

3939
class TestSingleTaskGP(BotorchTestCase):
40+
model_class = SingleTaskGP
41+
4042
def _get_model_and_data(
4143
self,
4244
batch_shape,
@@ -372,8 +374,33 @@ def test_set_transformed_inputs(self):
372374
tf_X = intf(X)
373375
self.assertEqual(X.shape, tf_X.shape)
374376

377+
def test_empty_inputs(self):
378+
empty_inputs = torch.ones(0, 2)
379+
kwargs = {
380+
"train_X": empty_inputs,
381+
"train_Y": empty_inputs,
382+
}
383+
if self.model_class is not SingleTaskGP:
384+
kwargs["train_Yvar"] = empty_inputs
385+
model = self.model_class(**kwargs)
386+
mll = ExactMarginalLogLikelihood(model.likelihood, model)
387+
fit_gpytorch_model(mll)
388+
X_prediction = torch.rand(3, 4, 2)
389+
with torch.no_grad():
390+
posterior = model.posterior(X_prediction)
391+
samples = posterior.rsample(sample_shape=torch.Size([5]))
392+
self.assertEqual(samples.shape, torch.Size([5, 3, 4, 2]))
393+
expected_mean = torch.zeros_like(posterior.mean)
394+
expected_var = torch.full_like(
395+
posterior.variance, fill_value=model.covar_module.outputscale[0].detach()
396+
)
397+
assert torch.equal(posterior.mean, expected_mean)
398+
assert torch.equal(posterior.variance, expected_var)
399+
375400

376401
class TestFixedNoiseGP(TestSingleTaskGP):
402+
model_class = FixedNoiseGP
403+
377404
def _get_model_and_data(
378405
self,
379406
batch_shape,
@@ -436,6 +463,8 @@ def test_construct_inputs(self):
436463

437464

438465
class TestHeteroskedasticSingleTaskGP(TestSingleTaskGP):
466+
model_class = HeteroskedasticSingleTaskGP
467+
439468
def _get_model_and_data(
440469
self, batch_shape, m, outcome_transform=None, input_transform=None, **tkwargs
441470
):

0 commit comments

Comments
 (0)