I'm implementing a BO loop with feasibility constraints along the lines of https://botorch.org/tutorials/constrained_multi_objective_bo
However, in my case evaluations of the feasibility constraint are binary (0/1), for which a GP classification model with a Bernoulli likelihood seems to be a suitable approach.
```python
import torch
import gpytorch


class GPClassificationModel(gpytorch.models.ApproximateGP):
    def __init__(self, train_x):
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(train_x.size(0))
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, train_x, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


model = GPClassificationModel(train_x)
likelihood = gpytorch.likelihoods.BernoulliLikelihood()
mll = gpytorch.mlls.VariationalELBO(likelihood, model, len(train_y), combine_terms=False)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

model.train()
likelihood.train()
for i in range(400):
    optimizer.zero_grad()
    output = model(train_x)
    # with combine_terms=False, the ELBO is returned as its separate terms
    log_lik, kl_div, log_prior = mll(output, train_y)
    loss = -(log_lik - kl_div + log_prior)
    loss.backward()
    optimizer.step()
```

Now I'm wondering how to feed this model, together with a SingleTaskGP, to the acquisition function.
Do I have to base my GPClassificationModel on ApproximateGP, or can I simply combine it with the SingleTaskGP in a ModelListGPyTorchModel?