diff --git a/src/anomalib/models/image/dsr/lightning_model.py b/src/anomalib/models/image/dsr/lightning_model.py
index 29f3ad9454..e9eb4d2693 100644
--- a/src/anomalib/models/image/dsr/lightning_model.py
+++ b/src/anomalib/models/image/dsr/lightning_model.py
@@ -94,7 +94,7 @@ def configure_optimizers(
     def on_train_start(self) -> None:
         """Load pretrained weights of the discrete model when starting training."""
         ckpt: Path = self.prepare_pretrained_model()
-        self.model.load_pretrained_discrete_model_weights(ckpt)
+        self.model.load_pretrained_discrete_model_weights(ckpt, self.device)
 
     def on_train_epoch_start(self) -> None:
         """Display a message when starting to train the upsampling module."""
diff --git a/src/anomalib/models/image/dsr/torch_model.py b/src/anomalib/models/image/dsr/torch_model.py
index 416e1b7451..e5f11a6618 100644
--- a/src/anomalib/models/image/dsr/torch_model.py
+++ b/src/anomalib/models/image/dsr/torch_model.py
@@ -81,9 +81,9 @@ def __init__(
         for parameters in self.discrete_latent_model.parameters():
             parameters.requires_grad = False
 
-    def load_pretrained_discrete_model_weights(self, ckpt: Path) -> None:
+    def load_pretrained_discrete_model_weights(self, ckpt: Path, device: torch.device | str | None = None) -> None:
         """Load pre-trained model weights."""
-        self.discrete_latent_model.load_state_dict(torch.load(ckpt))
+        self.discrete_latent_model.load_state_dict(torch.load(ckpt, map_location=device))
 
     def forward(
         self,
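
The fix threads the Lightning module's `self.device` through to `torch.load(..., map_location=...)`, so the pretrained discrete-model checkpoint is remapped onto whatever device training runs on instead of the device it was saved from. A minimal standalone sketch of the behavior being fixed is below; the checkpoint filename `vq_model.pckl` is illustrative, not taken from the diff:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Without map_location, torch.load restores tensors onto the device recorded
# in the checkpoint; e.g. a CUDA-saved checkpoint fails to load on a
# CPU-only machine with a RuntimeError. map_location remaps tensors to the
# requested device at load time.
state_dict = torch.load("vq_model.pckl", map_location=device)  # hypothetical path

# The remapped state dict can then be loaded into a module already on `device`:
# model.discrete_latent_model.load_state_dict(state_dict)
```

Passing `device=None` (the new parameter's default) keeps the old behavior, since `torch.load(ckpt, map_location=None)` is equivalent to the previous `torch.load(ckpt)` call, so existing callers are unaffected.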