DiffEqML / torchdyn

A PyTorch library entirely dedicated to neural differential equations, implicit models and related numerical methods
https://torchdyn.org
Apache License 2.0

Failure to save NeuralODE when using adjoint sensitivity #122

Closed. Bawaw closed this issue 2 years ago.

Bawaw commented 2 years ago

Describe the bug

Hey DiffEqML team,

I just encountered this error when attempting to save a trained model:

Traceback (most recent call last):
  File "save_fail.py", line 36, in <module>
    torch.save(model, 'save_test.pt')
  File "/home/bawaw/.local/lib/python3.8/site-packages/torch/serialization.py", line 379, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/home/bawaw/.local/lib/python3.8/site-packages/torch/serialization.py", line 484, in _save
    pickler.dump(obj)
AttributeError: Can't pickle local object '_gather_odefunc_adjoint.<locals>._ODEProblemFunc'

The problem only seems to occur when using the adjoint sensitivity (sensitivity='adjoint'); with sensitivity='autograd', saving works as expected.
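For context, the error comes from Python's pickle module, which torch.save uses by default. Pickle stores classes by their importable qualified name, so a class defined inside a function (a "local object" such as _gather_odefunc_adjoint.<locals>._ODEProblemFunc) cannot be looked up again and is rejected. The sketch below reproduces the same kind of failure with the standard library only; make_problem_func and Model are hypothetical stand-ins, not torchdyn code.

```python
import pickle


def make_problem_func():
    # Hypothetical stand-in for a factory like _gather_odefunc_adjoint:
    # the class is defined inside the function, so its qualified name
    # is 'make_problem_func.<locals>._ProblemFunc'.
    class _ProblemFunc:
        @classmethod
        def apply(cls, x):
            return x

    return _ProblemFunc


class Model:
    # Hypothetical stand-in for a model that stores the adjoint function.
    def __init__(self, adjoint: bool):
        # Storing a bound method of the local class makes the whole object
        # unpicklable, because pickling the method requires pickling the class.
        self.autograd_function = make_problem_func().apply if adjoint else None


def try_pickle(obj):
    """Return None if obj pickles, otherwise the error message."""
    try:
        pickle.dumps(obj)
        return None
    except (AttributeError, pickle.PicklingError) as err:
        return str(err)


print(try_pickle(Model(adjoint=False)))  # None: nothing local is stored
print(try_pickle(Model(adjoint=True)))   # an error naming the local class
```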

Best, Balder

Steps to Reproduce

import torch
import pytorch_lightning as pl
from torchdyn.core import NeuralODE

class Learner(pl.LightningModule):
    def __init__(self, model:torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x = batch[0]
        _, z = self.model(x, torch.linspace(0, 1, 100))
        loss = z.abs().mean()
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.01)

    def train_dataloader(self):
        dataset = torch.utils.data.TensorDataset(torch.randn(10, 1))
        return torch.utils.data.DataLoader(dataset)

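f = torch.nn.Sequential(
    torch.nn.Linear(1, 16),
    torch.nn.Tanh(),
    torch.nn.Linear(16, 1),
)

model = NeuralODE(f, sensitivity='adjoint')  # saving works with sensitivity='autograd'
learn = Learner(model)
trainer = pl.Trainer(max_epochs=1, gpus=1)
# trainer.fit(learn)  # training is not needed to trigger the error
torch.save(model, 'save_test.pt')  # raises AttributeError: Can't pickle local object ...
model = torch.load('save_test.pt')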

Expected behavior

The model should be saved to the file 'save_test.pt', just as it is when using sensitivity='autograd'.

Zymrael commented 2 years ago

Thank you for reporting the issue. This is actually something that is known internally and will be fixed in an upcoming release.

If you need a quick fix, setting the model's .autograd_function attribute to None (or to any other value that can be pickled) before saving should resolve this.
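A minimal sketch of that workaround, assuming the unpicklable object lives in an attribute named autograd_function as described above. It uses hypothetical stand-ins (make_problem_func, Model, without_attr) and plain pickle in place of torch.save, and temporarily swaps the attribute for None during saving so the in-memory model keeps working afterwards.

```python
import io
import pickle
from contextlib import contextmanager


def make_problem_func():
    # Hypothetical stand-in for torchdyn's internal factory: it returns a
    # locally defined class, which pickle cannot serialize.
    class _ProblemFunc:
        @classmethod
        def apply(cls, x):
            return x

    return _ProblemFunc


class Model:
    # Hypothetical stand-in for a NeuralODE built with sensitivity='adjoint'.
    def __init__(self):
        self.weights = [0.1, 0.2]
        self.autograd_function = make_problem_func().apply


@contextmanager
def without_attr(obj, name, placeholder=None):
    """Temporarily replace obj.<name> with a picklable placeholder."""
    original = getattr(obj, name)
    setattr(obj, name, placeholder)
    try:
        yield obj
    finally:
        # Restore the original so the in-memory model keeps working.
        setattr(obj, name, original)


def save(obj, buffer):
    # pickle.dump plays the role of torch.save in this sketch.
    with without_attr(obj, "autograd_function"):
        pickle.dump(obj, buffer)


model = Model()
buf = io.BytesIO()
save(model, buf)
buf.seek(0)
restored = pickle.load(buf)
print(restored.autograd_function)  # None
```

Note that the reloaded object carries None in place of the adjoint function, so it would need to be rebuilt (or the model re-created) before running it with the adjoint sensitivity. Saving model.state_dict() with torch.save and loading it into a freshly constructed NeuralODE(f, sensitivity='adjoint') via load_state_dict sidesteps pickling the module object altogether.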

alokwarey commented 2 years ago

Model save still doesn't work in version 1.0.2. Any plans to fix this bug?

fedebotu commented 2 years ago

We fixed the issue by updating release 1.0.2 :)
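The thread does not say exactly what changed in release 1.0.2. For readers facing the same problem in their own modules, one common Python pattern is to drop the unpicklable attribute in __getstate__ and rebuild it from picklable configuration in __setstate__. The sketch below shows that pattern on hypothetical stand-ins; it is not torchdyn's implementation.

```python
import pickle


def make_problem_func(scale):
    # Hypothetical factory returning a locally defined class (unpicklable).
    class _ProblemFunc:
        @classmethod
        def apply(cls, x):
            return scale * x

    return _ProblemFunc


class Model:
    # Hypothetical stand-in; not torchdyn's actual code.
    def __init__(self, scale):
        self.scale = scale
        self._build()

    def _build(self):
        # Rebuild the unpicklable helper from picklable configuration.
        self.autograd_function = make_problem_func(self.scale).apply

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("autograd_function", None)  # drop the local-class method
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._build()  # recreate it after unpickling


restored = pickle.loads(pickle.dumps(Model(2.0)))
print(restored.autograd_function(3))  # 6.0
```

Another option is to not store the function on the instance at all and build it inside forward on each call. For a torch.nn.Module, which already defines __setstate__, an override would also need to call super().__setstate__(state).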
