rtqichen / torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.
MIT License

Adjoint method breaks after reaching a certain level of performance #211

Open Lancial opened 2 years ago

Lancial commented 2 years ago

Hi,

I've been using this library for an image-based flow estimation task. I use an ODE solver to solve for an evolving spatial transformation: an initial zero flow, concatenated with some additional features, is passed to odeint/odeint_adjoint. The ODE function is a convolutional neural network. My code looks roughly like this:

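import torch
import torch.nn as nn
from torchdiffeq import odeint, odeint_adjoint

class CNN(nn.Module):
    def __init__(self, layers):
        super().__init__()
        self.layers = layers  # conv layers: full state -> delta_flow

    def forward(self, t, x):
        delta_flow = self.layers(x)  # run through all layers
        # match input dimension: features remain static (zero derivative)
        zeros = torch.zeros_like(x[:, :x.shape[1] - delta_flow.shape[1]])
        return torch.cat([zeros, delta_flow], 1)

class Model(nn.Module):
    def __init__(self, encoder, layers, flow_dim, adjoint=True):
        super().__init__()
        self.encoder = encoder  # an image feature extractor
        self.odefunc = CNN(layers)
        self.flow_dim = flow_dim
        self.solve = odeint_adjoint if adjoint else odeint

    def forward(self, x):
        x = self.encoder(x)
        zero_flow = x.new_zeros(x.shape[0], self.flow_dim, *x.shape[2:])
        ode_x = torch.cat([x, zero_flow], 1)
        t = x.new_tensor([0.0, 1.0])  # assumed interval; step_size below is also assumed
        ode_y = self.solve(self.odefunc, ode_x, t, method='euler', options=dict(step_size=0.1))
        flow = ode_y[-1, :, -self.flow_dim:]  # flow channels at the final time
        return flow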

My network converges with the non-adjoint method. With the adjoint method, however, the model converges initially, but the loss always explodes once it reaches a certain level of performance (the network produces quite accurate results before it breaks). The loss functions are a typical image-similarity loss plus a regularization loss, and I use a fixed-step Euler solver. Do you know what could cause this? I'd highly appreciate any suggestions!