Closed · Label: bug (Something isn't working)
Describe the bug
Hey DiffEqML team,
I just encountered this error when attempting to save a trained model:
```
Traceback (most recent call last):
  File "save_fail.py", line 36, in <module>
    torch.save(model, 'save_test.pt')
  File "/home/bawaw/.local/lib/python3.8/site-packages/torch/serialization.py", line 379, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/home/bawaw/.local/lib/python3.8/site-packages/torch/serialization.py", line 484, in _save
    pickler.dump(obj)
AttributeError: Can't pickle local object '_gather_odefunc_adjoint.<locals>._ODEProblemFunc'
```
The problem only seems to occur with `sensitivity='adjoint'`. With `sensitivity='autograd'`, the model saves as expected.
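The traceback names a class defined inside a function (`_gather_odefunc_adjoint.<locals>._ODEProblemFunc`). pickle stores a class by reference (its module plus qualified name), so a class whose qualified name contains `<locals>` cannot be located again and pickling fails. A minimal pure-Python sketch of that behaviour, assuming nothing about torchdyn's internals beyond the name in the traceback; `make_local_instance` and `ModuleLevelFunc` are hypothetical names used only for illustration:

```python
import pickle


def make_local_instance():
    # Defined inside a function: its __qualname__ contains '<locals>',
    # so pickle cannot look the class up again by name.
    class _LocalFunc:
        def __call__(self, x):
            return 2 * x

    return _LocalFunc()


class ModuleLevelFunc:
    # Defined at module level: pickle can find it as <module>.ModuleLevelFunc.
    def __call__(self, x):
        return 2 * x


def try_pickle(obj) -> bool:
    """Return True if obj survives a pickle round trip, False if pickling fails."""
    try:
        return pickle.loads(pickle.dumps(obj))(3) == obj(3)
    except (AttributeError, pickle.PicklingError):
        # Older Pythons raise AttributeError here; newer ones may raise PicklingError.
        return False


print(try_pickle(make_local_instance()))  # False
print(try_pickle(ModuleLevelFunc()))      # True
```

This suggests that moving such a helper class to module level would make the object picklable, though this is a generic illustration rather than the library's actual fix.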
Best,
Balder
Steps to Reproduce
```python
import torch
import pytorch_lightning as pl
from torchdyn.core import NeuralODE


class Learner(pl.LightningModule):
    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x = batch[0]
        _, z = self.model(x, torch.linspace(0, 1, 100))
        loss = z.abs().mean()
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.01)

    def train_dataloader(self):
        dataset = torch.utils.data.TensorDataset(torch.randn(10, 1))
        return torch.utils.data.DataLoader(dataset)


f = torch.nn.Sequential(
    torch.nn.Linear(1, 16),
    torch.nn.Tanh(),
    torch.nn.Linear(16, 1)
)

model = NeuralODE(f, sensitivity='adjoint')
learn = Learner(model)
trainer = pl.Trainer(max_epochs=1, gpus=1)
#trainer.fit(learn)

torch.save(model, 'save_test.pt')
model = torch.load('save_test.pt')
```
Expected behavior
The model should be saved to `save_test.pt` and load back, just as it does with `sensitivity='autograd'`.
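As a possible workaround in the meantime, a common PyTorch practice is to save only `model.state_dict()` and rebuild the module in code before loading the weights. The sketch below does not import torchdyn: `HypotheticalODEWrapper` and `build_model` are hypothetical stand-ins that reproduce the same kind of failure (an attribute whose class is defined inside a function), so treat it as an illustration under that assumption, not as the maintainers' fix.

```python
import io
import pickle

import torch
import torch.nn as nn


class HypotheticalODEWrapper(nn.Module):
    """Hypothetical stand-in for NeuralODE that keeps a helper with a local class."""

    def __init__(self, f: nn.Module):
        super().__init__()
        self.f = f

        # Class defined inside __init__, mimicking the '<locals>' class in the traceback.
        class _LocalHelper:
            pass

        # Plain attribute (not a Module or Parameter), so it is pickled with the module.
        self.helper = _LocalHelper()

    def forward(self, x):
        return self.f(x)


def build_model() -> nn.Module:
    # Same vector field as in the report.
    f = nn.Sequential(nn.Linear(1, 16), nn.Tanh(), nn.Linear(16, 1))
    return HypotheticalODEWrapper(f)


model = build_model()

# 1) Saving the whole module pickles self.helper, which fails.
try:
    torch.save(model, io.BytesIO())
    whole_model_saved = True
except (AttributeError, pickle.PicklingError):
    whole_model_saved = False

# 2) Workaround: save only the tensors in state_dict(), then rebuild and load.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

restored = build_model()
restored.load_state_dict(torch.load(buffer, weights_only=True))

print("whole model saved:", whole_model_saved)
```

For the model in the report, the assumption is that `build_model` would construct `NeuralODE(f, sensitivity='adjoint')` with the same `f`, and `load_state_dict` would then restore its trained weights.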