odlgroup / odl

Operator Discretization Library https://odlgroup.github.io/odl/
Mozilla Public License 2.0

Torch OperatorAsModule + RayTransform integration for backpropagation #1523

Closed Zakobian closed 5 years ago

Zakobian commented 5 years ago

Hello, I have been trying to do gradient descent on a simple MNIST image with a RayTransform forward operator, but the 'OperatorAsModule' does not seem to be able to propagate backwards after having done so once. The following is an example:

import odl
from odl.contrib.torch import OperatorAsModule
import torch
from torch.autograd import Variable
from torch import optim
import torchvision
from torchvision import datasets, transforms

trafo = transforms.Compose([transforms.ToTensor()])
dset_train = datasets.MNIST('./', train=True, download=True, transform=trafo)
train_loader = torch.utils.data.DataLoader(dset_train, batch_size=50, shuffle=True)

space = odl.uniform_discr([-14, -14], [14, 14], [28, 28], dtype='float32')
geometry = odl.tomo.parallel_beam_geometry(space, num_angles=5)
fwd_op = odl.tomo.RayTransform(space, geometry)
fbp_op = odl.tomo.fbp_op(fwd_op)
fwd_op_mod = OperatorAsModule(fwd_op)
fbp_op_mod = OperatorAsModule(fbp_op)

# Forward-project the images and add white noise to simulate measured data.
def generate_data(images):
    torch.manual_seed(123)
    data = fwd_op_mod(images)
    data += Variable(torch.randn(data.shape)).type_as(data)
    return data

# Generator so the single parameter can be handed to the optimizer.
def f(output):
    yield output

# Take one batch of images from the training loader.
for i, (images, _) in enumerate(train_loader):
    if i == 1:
        break

test_data = generate_data(Variable(images))
guess = fbp_op_mod(test_data)[0:1, 0:1].clone()  # initial guess: FBP reconstruction
truth = Variable(images)[0:1, 0:1]

guess = torch.nn.Parameter(guess)
optimizer = optim.SGD(f(guess), lr=1e-2)
i = 0
while i < 200:
    optimizer.zero_grad()
    i += 1
    loss = torch.dist(fwd_op_mod(guess), test_data, p=2)
    print(loss.item())
    loss.backward()
    print(guess.grad)
    optimizer.step()

After one iteration it is no longer able to produce a gradient. However, adding

    fwd_op_mod = OperatorAsModule(fwd_op)

inside the loop fixes the problem: the gradient exists and gradient descent works successfully. I have tried different optimizers, and also using torch.autograd directly instead of an optimizer, but the issue arises nonetheless.
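Concretely, the loop with that workaround looks roughly like this (same names as in the script above):

i = 0
while i < 200:
    optimizer.zero_grad()
    i += 1
    # Re-wrapping the operator each iteration makes backward() work again.
    fwd_op_mod = OperatorAsModule(fwd_op)
    loss = torch.dist(fwd_op_mod(guess), test_data, p=2)
    loss.backward()
    print(guess.grad)
    optimizer.step()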

Unless there is an issue with the way I handle things, it seems that after one forward pass through the module it can no longer produce gradients, i.e. if I do

    fwd_op_mod(guess)

before the loop, the gradient stays at None.

I use odl 0.7.0, torch 1.1.0 and astra 1.8.3. Please let me know if any extra information is required.

kohr-h commented 5 years ago

Hey @Zakobian, sorry for the delay. I recently noticed that I had gotten the torch API slightly backwards the first time. In fact, the torch documentation says that an autograd.Function can only be used once (for graph building), so that may be the issue you are encountering.

There's an open PR #1516 with a better API (and updated examples) where the torch API is used correctly. It will likely be merged soon, but if you want to give it a try, just check out my torch_better_api branch.
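The gist of the correct usage is a torch.autograd.Function with static forward/backward methods that is re-applied via .apply() on every call, instead of reusing a single Function instance. The following is only a rough sketch of that pattern, not the code from the PR; it assumes a linear ODL operator op whose domain matches the input shape (no batch axes):

import torch


class OperatorFunction(torch.autograd.Function):
    """Sketch: differentiate through a linear ODL operator ``op``."""

    @staticmethod
    def forward(ctx, op, x):
        # Keep the operator around for the backward pass (non-tensor state goes on ctx).
        ctx.op = op
        result = op(x.detach().cpu().numpy())
        return torch.from_numpy(result.asarray()).to(x.device)

    @staticmethod
    def backward(ctx, grad_output):
        # For a linear operator, the backward pass applies the adjoint.
        grad = ctx.op.adjoint(grad_output.detach().cpu().numpy())
        # No gradient for the ``op`` argument, adjoint result for ``x``.
        return None, torch.from_numpy(grad.asarray()).to(grad_output.device)


# Applied freshly on every forward call, so backward() keeps working:
#     y = OperatorFunction.apply(fwd_op, x)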