DiffEqML / torchdyn

A PyTorch library entirely dedicated to neural differential equations, implicit models and related numerical methods
https://torchdyn.org
Apache License 2.0

Is there a way to use multiple optimizers with a NeuralODE model? #170

Open · mareikethies opened 1 year ago

mareikethies commented 1 year ago

Hi, I noticed that my code fails when using two optimizers instead of one, each for a different group of parameters. The minimal example below is taken from your repo; I only altered the configure_optimizers() method in the LightningModule. This change breaks the code, and it fails with:

RuntimeError: One of the differentiated Tensors does not require grad

Is this expected? Is there a way to use multiple optimizers in a different way?

The minimal example is:

from torchdyn.datasets import ToyDataset
from torchdyn.core import NeuralODE
import torch
import torch.utils.data as data

import torch.nn as nn
import pytorch_lightning as pl

device = torch.device("cuda:0")

class Learner(pl.LightningModule):
    def __init__(self, t_span: torch.Tensor, model: nn.Module):
        super().__init__()
        self.model, self.t_span = model, t_span

    def forward(self, x):
        return self.model(x)

    # optimizer_idx is passed by Lightning when multiple optimizers are configured
    def training_step(self, batch, batch_idx, optimizer_idx):
        x, y = batch
        t_eval, y_hat = self.model(x, self.t_span)
        y_hat = y_hat[-1]  # select last point of solution trajectory
        loss = nn.CrossEntropyLoss()(y_hat, y)
        return {'loss': loss}

    def configure_optimizers(self):
        # single optimizer (this works):
        # return torch.optim.Adam(self.model.parameters(), lr=0.01)
        # two optimizers over disjoint parameter groups (this is what fails):
        optimizer1 = torch.optim.Adam(list(self.model.parameters())[:2], lr=0.01)
        optimizer2 = torch.optim.Adam(list(self.model.parameters())[2:], lr=0.01)
        return ({'optimizer': optimizer1},
                {'optimizer': optimizer2})

    def train_dataloader(self):
        return trainloader

# toy 'moons' classification dataset
d = ToyDataset()
X, yn = d.generate(n_samples=512, dataset_type='moons', noise=.1)
X_train = torch.Tensor(X).to(device)
y_train = torch.LongTensor(yn.long()).to(device)
train = data.TensorDataset(X_train, y_train)
trainloader = data.DataLoader(train, batch_size=len(X), shuffle=True)

# vector field parameterized by a small MLP
f = nn.Sequential(
        nn.Linear(2, 64),
        nn.Tanh(),
        nn.Linear(64, 2))

# integration interval: solve from t=0 to t=1
t_span = torch.linspace(0, 1, 2)

# NeuralODE wrapping f, trained with adjoint sensitivity
model = NeuralODE(f, sensitivity='adjoint', solver='tsit5', interpolator=None, atol=1e-3, rtol=1e-3).to(device)

learn = Learner(t_span, model)
trainer = pl.Trainer(min_epochs=200, max_epochs=250)
trainer.fit(learn)
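
For what it's worth, the only alternative I can think of is a single optimizer with two parameter groups instead of two separate optimizers. As a sketch of what I mean (untested, and the learning rates are just placeholders), configure_optimizers() would become:

def configure_optimizers(self):
    # sketch: one Adam instance with two parameter groups instead of two optimizers
    params = list(self.model.parameters())
    return torch.optim.Adam([
        {'params': params[:2], 'lr': 0.01},  # first group of parameters
        {'params': params[2:], 'lr': 0.01},  # remaining parameters
    ])

That is only a workaround though, not really multiple optimizers, so I would still like to know whether the two-optimizer setup above is supposed to work with the adjoint.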