4242from botorch .models .utils .assorted import get_task_value_remapping
4343from botorch .models .utils .gpytorch_modules import (
4444 get_covar_module_with_dim_scaled_prior ,
45+ get_gaussian_likelihood_with_lognormal_prior ,
4546 MIN_INFERRED_NOISE_LEVEL ,
4647)
4748from botorch .posteriors .multitask import MultitaskGPPosterior
5556from gpytorch .kernels .index_kernel import IndexKernel
5657from gpytorch .kernels .multitask_kernel import MultitaskKernel
5758from gpytorch .likelihoods .gaussian_likelihood import FixedNoiseGaussianLikelihood
58- from gpytorch .likelihoods .hadamard_gaussian_likelihood import HadamardGaussianLikelihood
5959from gpytorch .likelihoods .likelihood import Likelihood
6060from gpytorch .likelihoods .multitask_gaussian_likelihood import (
6161 MultitaskGaussianLikelihood ,
@@ -115,7 +115,6 @@ def __init__(
115115 all_tasks : list [int ] | None = None ,
116116 outcome_transform : OutcomeTransform | _DefaultType | None = DEFAULT ,
117117 input_transform : InputTransform | None = None ,
118- validate_task_values : bool = True ,
119118 ) -> None :
120119 r"""Multi-Task GP model using an ICM kernel.
121120
@@ -158,9 +157,6 @@ def __init__(
158157 instantiation of the model.
159158 input_transform: An input transform that is applied in the model's
160159 forward pass.
161- validate_task_values: If True, validate that the task values supplied in the
162- input are expected tasks values. If false, unexpected task values
163- will be mapped to the first output_task if supplied.
164160
165161 Example:
166162 >>> X1, X2 = torch.rand(10, 2), torch.rand(20, 2)
@@ -193,7 +189,7 @@ def __init__(
193189 "This is not allowed as it will lead to errors during model training."
194190 )
195191 all_tasks = all_tasks or all_tasks_inferred
196- self .num_tasks = len (all_tasks_inferred )
192+ self .num_tasks = len (all_tasks )
197193 if outcome_transform == DEFAULT :
198194 outcome_transform = Standardize (m = 1 , batch_shape = train_X .shape [:- 2 ])
199195 if outcome_transform is not None :
@@ -212,20 +208,10 @@ def __init__(
212208 self ._output_tasks = output_tasks
213209 self ._num_outputs = len (output_tasks )
214210
211+ # TODO (T41270962): Support task-specific noise levels in likelihood
215212 if likelihood is None :
216213 if train_Yvar is None :
217- noise_prior = LogNormalPrior (loc = - 4.0 , scale = 1.0 )
218- likelihood = HadamardGaussianLikelihood (
219- num_tasks = self .num_tasks ,
220- batch_shape = torch .Size (),
221- noise_prior = noise_prior ,
222- noise_constraint = GreaterThan (
223- MIN_INFERRED_NOISE_LEVEL ,
224- transform = None ,
225- initial_value = noise_prior .mode ,
226- ),
227- task_feature_index = task_feature ,
228- )
214+ likelihood = get_gaussian_likelihood_with_lognormal_prior ()
229215 else :
230216 likelihood = FixedNoiseGaussianLikelihood (noise = train_Yvar .squeeze (- 1 ))
231217
@@ -263,60 +249,31 @@ def __init__(
263249
264250 self .covar_module = data_covar_module * task_covar_module
265251 task_mapper = get_task_value_remapping (
266- observed_task_values = torch .tensor (
267- all_tasks_inferred , dtype = torch .long , device = train_X .device
268- ),
269- all_task_values = torch .tensor (
270- sorted (all_tasks ), dtype = torch .long , device = train_X .device
252+ task_values = torch .tensor (
253+ all_tasks , dtype = torch .long , device = train_X .device
271254 ),
272255 dtype = train_X .dtype ,
273- default_task_value = None if output_tasks is None else output_tasks [0 ],
274256 )
275257 self .register_buffer ("_task_mapper" , task_mapper )
276- self ._expected_task_values = set (all_tasks_inferred )
258+ self ._expected_task_values = set (all_tasks )
277259 if input_transform is not None :
278260 self .input_transform = input_transform
279261 if outcome_transform is not None :
280262 self .outcome_transform = outcome_transform
281- self ._validate_task_values = validate_task_values
282263 self .to (train_X )
283264
284265 def _map_tasks (self , task_values : Tensor ) -> Tensor :
285- """Map raw task values to the task indices used by the model .
266+ """Map task values to contiguous integers using the task mapper .
286267
287268 Args:
288- task_values: A tensor of task values .
269+ task_values: A tensor of task indices to be mapped .
289270
290271 Returns:
291- A tensor of task indices with the same shape as the input
292- tensor.
272+ A tensor of mapped task indices.
293273 """
294- long_task_values = task_values .long ()
295- if self ._validate_task_values :
296- if self ._task_mapper is None :
297- if not (
298- torch .all (0 <= task_values )
299- and torch .all (task_values < self .num_tasks )
300- ):
301- raise ValueError (
302- "Expected all task features in `X` to be between 0 and "
303- f"self.num_tasks - 1. Got { task_values } ."
304- )
305- else :
306- unexpected_task_values = set (
307- long_task_values .unique ().tolist ()
308- ).difference (self ._expected_task_values )
309- if len (unexpected_task_values ) > 0 :
310- raise ValueError (
311- "Received invalid raw task values. Expected raw value to be in"
312- f" { self ._expected_task_values } , but got unexpected task"
313- f" values: { unexpected_task_values } ."
314- )
315- task_values = self ._task_mapper [long_task_values ]
316- elif self ._task_mapper is not None :
317- task_values = self ._task_mapper [long_task_values ]
318-
319- return task_values
274+ if self ._task_mapper is None :
275+ return task_values .to (dtype = self .train_targets .dtype )
276+ return self ._task_mapper [task_values ].to (dtype = self .train_targets .dtype )
320277
321278 def _split_inputs (self , x : Tensor ) -> tuple [Tensor , Tensor , Tensor ]:
322279 r"""Extracts features before task feature, task indices, and features after
@@ -330,7 +287,7 @@ def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
330287 3-element tuple containing
331288
332289 - A `q x d` or `b x q x d` tensor with features before the task feature
333- - A `q` or `b x q x 1 ` tensor with mapped task indices
290+ - A `q` or `b x q` tensor with mapped task indices
334291 - A `q x d` or `b x q x d` tensor with features after the task feature
335292 """
336293 batch_shape = x .shape [:- 2 ]
@@ -370,7 +327,7 @@ def get_all_tasks(
370327 raise ValueError (f"Must have that -{ d } <= task_feature <= { d } " )
371328 task_feature = task_feature % (d + 1 )
372329 all_tasks = (
373- train_X [..., task_feature ].to (dtype = torch .long ). unique ( sorted = True ).tolist ()
330+ train_X [..., task_feature ].unique ( sorted = True ). to (dtype = torch .long ).tolist ()
374331 )
375332 return all_tasks , task_feature , d
376333
0 commit comments