Failure to save NeuralODE when using adjoint sensitivity #122

@Bawaw

Describe the bug

Hey DiffEqML team,

I just encountered this error when attempting to save a trained model:

Traceback (most recent call last):
  File "save_fail.py", line 36, in <module>
    torch.save(model, 'save_test.pt')
  File "/home/bawaw/.local/lib/python3.8/site-packages/torch/serialization.py", line 379, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/home/bawaw/.local/lib/python3.8/site-packages/torch/serialization.py", line 484, in _save
    pickler.dump(obj)
AttributeError: Can't pickle local object '_gather_odefunc_adjoint.<locals>._ODEProblemFunc'

The problem only seems to occur when using the adjoint sensitivity; with the autograd sensitivity, saving works as expected.
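The traceback names `_gather_odefunc_adjoint.<locals>._ODEProblemFunc`, which suggests that the adjoint code path keeps an object whose class is defined inside the function `_gather_odefunc_adjoint`. `torch.save` pickles the whole module (see `pickler.dump` in the traceback), and Python's `pickle` stores classes by reference to their importable qualified name. A name containing `<locals>` cannot be looked up again, so such objects cannot be pickled. Below is a minimal, torch-free sketch of that failure mode; `make_problem_func`, `LocalFunc`, `TopLevelFunc` and `try_pickle` are hypothetical names used only for illustration.

```python
import pickle


def make_problem_func():
    # A class defined inside a function gets a qualified name containing
    # "<locals>", so pickle cannot find it again by module + name.
    class LocalFunc:
        def __call__(self, x):
            return 2 * x

    return LocalFunc()


class TopLevelFunc:
    # The same class defined at module level pickles without issue.
    def __call__(self, x):
        return 2 * x


def try_pickle(obj):
    """Return None if obj pickles, otherwise the error message."""
    try:
        pickle.dumps(obj)
        return None
    except (AttributeError, pickle.PicklingError) as err:
        # Python 3.8 raises AttributeError here; newer versions may
        # raise PicklingError instead.
        return str(err)


print(try_pickle(make_problem_func()))  # error message mentioning the local class
print(try_pickle(TopLevelFunc()))       # None
```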

Best,
Balder

Steps to Reproduce

import torch
import pytorch_lightning as pl
from torchdyn.core import NeuralODE

class Learner(pl.LightningModule):
    def __init__(self, model:torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x = batch[0]
        _, z = self.model(x, torch.linspace(0, 1, 100))
        loss = z.abs().mean()
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.01)

    def train_dataloader(self):
        dataset = torch.utils.data.TensorDataset(torch.randn(10, 1))
        return torch.utils.data.DataLoader(dataset)

f = torch.nn.Sequential(
        torch.nn.Linear(1, 16),
        torch.nn.Tanh(),
        torch.nn.Linear(16, 1)
    )

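model = NeuralODE(f, sensitivity='adjoint')  # with sensitivity='autograd', saving works
learn = Learner(model)
trainer = pl.Trainer(max_epochs=1, gpus=1)
#trainer.fit(learn)  # training is not required to trigger the error
torch.save(model, 'save_test.pt')  # fails with the AttributeError shown above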
model = torch.load('save_test.pt')

Expected behavior

The model should be saved to the file 'save_test.pt', just as it is when using the autograd sensitivity.
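As a user-side workaround, the usual PyTorch alternative to pickling the whole module is to save only `model.state_dict()` and load it into a freshly constructed `NeuralODE` with `load_state_dict`. On the library side, one generic way to keep an object picklable while it caches a helper built from a function-local class is to drop that helper in `__getstate__` and rebuild it in `__setstate__`. The sketch below shows that pattern in plain Python; `build_adjoint_func` and `ProblemWrapper` are hypothetical, and this is not torchdyn's actual implementation.

```python
import pickle


def build_adjoint_func(scale):
    # Hypothetical factory: returns an instance of a function-local class,
    # which pickle cannot serialize on its own.
    class _LocalAdjointFunc:
        def __call__(self, x):
            return scale * x

    return _LocalAdjointFunc()


class ProblemWrapper:
    """Hypothetical object that caches an unpicklable helper."""

    def __init__(self, scale):
        self.scale = scale
        self._adjoint_func = build_adjoint_func(scale)

    def __getstate__(self):
        # Copy the instance dict and drop the cached local-class object.
        state = self.__dict__.copy()
        state["_adjoint_func"] = None
        return state

    def __setstate__(self, state):
        # Restore plain attributes, then rebuild the helper from them.
        self.__dict__.update(state)
        self._adjoint_func = build_adjoint_func(self.scale)

    def __call__(self, x):
        return self._adjoint_func(x)


original = ProblemWrapper(scale=3.0)
restored = pickle.loads(pickle.dumps(original))
print(restored(2.0))  # 6.0
```

The design choice is that only plain, reconstructible state (here `scale`) is serialized; anything derived from it is recreated on load.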

Metadata

Labels: bug (Something isn't working)