matteo-ronchetti / torch-radon

Computational Tomography in PyTorch
https://torch-radon.readthedocs.io
GNU General Public License v3.0

backprop issues after forward projection #16

Closed HJ-harry closed 3 years ago

HJ-harry commented 4 years ago

Hi,

First of all thanks for the great library. It's very intuitive to use :+1:

However, I have an issue where backward() fails after the forward projection followed by some permutations of the result, due to a contiguity problem.

The task I'm working on involves 3D objects, which should be possible since torch-radon supports a batch dimension. If I have a 3D object of shape (batch, channel, h, w, d), where h is height, w is width, and d is depth, I have to permute the dimensions before computing the forward projection. After the projection, the sinogram (or stack of 2D projections in this case) has shape (d, N, det), where N is the number of projection angles and det is the number of detector pixels. To view this as multiple 2D projections, I need to permute the dimensions again so that the tensor has shape (batch, N, det, d). The code below reproduces the error I am encountering.

Code

    import numpy as np
    import torch
    import torch.nn as nn
    from torch_radon import Radon

    device = 'cuda:0'

    img_size = 128
    n_angles = 3
    geometry = 'parallel'

    # Parallel beam geometry
    if geometry == 'parallel':
        angles = np.linspace(0, np.pi, n_angles, endpoint=False)
        radon = Radon(img_size, angles, clip_to_circle=False)

    l1_loss = nn.L1Loss(reduction='mean')
    #########################################################
    # Case 1: No ERROR
    #########################################################
    # Random 3D object in pytorch dimensions - (b x c x h x w x d)
    obj = torch.randn([1, 1, 128, 128, 128], device=device, requires_grad=True)

    target = torch.ones([1, 3, 128, 128]).to(device)
    target = proj_bpxy_to_ypx(target)

    sino = radon.forward(vol_bchwz_to_zhw(obj))
    loss = l1_loss(sino, target)
    loss.backward()

    #########################################################
    # Case 2: ERROR
    #########################################################
    # Random 3D object in pytorch dimensions - (b x c x h x w x d)
    obj = torch.randn([1, 1, 128, 128, 128], device=device, requires_grad=True)

    target = torch.ones([1, 3, 128, 128]).to(device)

    sino = proj_ypx_to_bpxy(radon.forward(vol_bchwz_to_zhw(obj)))
    loss = l1_loss(sino, target)
    loss.backward()

The helper functions are defined as follows:

    def proj_bpxy_to_ypx(tensor):
        """
        Ex. 3-view
        From ~ Input projection      : [3(p), 128(x), 128(y)] or [1(b), 3(p), 128(x), 128(y)]
        To   ~ TorchRadon projection : [128(y), 3(p), 128(x)]
        """
        assert isinstance(tensor, torch.Tensor), 'input must be a torch tensor'
        if tensor.dim() == 4:
            tensor = tensor.squeeze(0)
        return tensor.permute(2, 0, 1).contiguous()

    def proj_ypx_to_bpxy(tensor):
        """
        Ex. 3-view
        TorchRadon projection : [128(y), 3(p), 128(x)]
        Input projection      : [1(b), 3(p), 128(x), 128(y)]
        """
        assert isinstance(tensor, torch.Tensor), 'input must be a torch tensor'
        assert tensor.dim() == 3, f'Dimension must be e.g. [128(y), 3(p), 128(x)]. Received {tensor.shape}'
        y, p, x = tensor.shape
        return tensor.permute(1, 2, 0).contiguous().view(1, p, x, y)

    def vol_zhw_to_bchwz(tensor):
        """
        Ex.
        TorchRadon BP : [z, h, w]
        Volume shape for 3D Unet Generator: [b, c, h, w, z]
        """
        assert isinstance(tensor, torch.Tensor), 'input must be a torch tensor'
        assert tensor.dim() == 3, f'Dimension must be e.g. [z, h, w]. Received {tensor.shape}'
        z, h, w = tensor.shape
        return tensor.permute(1, 2, 0).contiguous().view(1, 1, h, w, z)

    def vol_bchwz_to_zhw(tensor):
        """
        Ex.
        Volume shape for 3D Unet Generator: [b, c, h, w, z]
        TorchRadon BP : [z, h, w]
        """
        assert isinstance(tensor, torch.Tensor), 'input must be a torch tensor'
        assert tensor.dim() == 5, f'Dimension must be e.g. [b, c, h, w, z]. Received {tensor.shape}'
        # permute (not view) moves the z axis to the front; a plain view keeps
        # memory order and would scramble the slices
        return tensor.squeeze(0).squeeze(0).permute(2, 0, 1).contiguous()
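As a quick sanity check on axis conversions of this kind (a self-contained sketch with hypothetical `to_zhw`/`to_bchwz` helpers, independent of the code above): moving the depth axis requires a permute, since a plain view/reshape on the squeezed tensor keeps memory order and scrambles the slices. A permute-based pair round-trips exactly:

```python
import torch

def to_zhw(vol):
    # [1, 1, h, w, z] -> [z, h, w]
    return vol.squeeze(0).squeeze(0).permute(2, 0, 1).contiguous()

def to_bchwz(stack):
    # [z, h, w] -> [1, 1, h, w, z]
    return stack.permute(1, 2, 0).contiguous().unsqueeze(0).unsqueeze(0)

vol = torch.arange(24, dtype=torch.float32).view(1, 1, 2, 3, 4)
assert to_zhw(vol).shape == (4, 2, 3)          # z moved to the front
assert torch.equal(to_bchwz(to_zhw(vol)), vol)  # exact round trip
```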

For Case 1, there is no error. However, if I run Case 2, I get the following error message:

    RuntimeError: x must be contiguous

I guess this error is usually solved with contiguous() (https://stackoverflow.com/questions/48915810/pytorch-contiguous), and I've tried inserting it in every possible place before and after the permute operations, but it still produces the same error. Am I doing something wrong? Could you help me fix this issue?

Thanks a lot in advance

matteo-ronchetti commented 4 years ago

Thanks for the appreciation and for the detailed bug report. The Python "frontend" of TorchRadon takes care of making inputs contiguous; the problem is that this functionality was not implemented in the gradient computation handlers. In your situation the gradient is non-contiguous (which is why you weren't able to solve the problem by adding contiguous()), and this breaks the library.
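The diagnosis above can be reproduced in isolation: the backward pass of permute returns a permuted view, so the gradient arriving at earlier ops is non-contiguous even when every forward tensor was made contiguous. A minimal sketch in plain PyTorch (no torch-radon needed):

```python
import torch

x = torch.randn(4, 5, requires_grad=True)
seen = {}
# a hook sees the gradient exactly as it arrives at x
x.register_hook(lambda g: seen.update(contiguous=g.is_contiguous()))

y = x.permute(1, 0).contiguous()  # forward output IS contiguous
y.sum().backward()

print(seen["contiguous"])  # False: permute's backward returns a strided view
```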

I have pushed a fix for this problem; newly compiled binaries should be available after this build finishes (it usually takes about an hour). Please wait for the build to finish and then reinstall the library.
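For anyone pinned to an older build, a possible user-side stopgap (my sketch, not part of torch-radon's API) is an identity autograd Function that forces the incoming gradient to be contiguous; wrapping the projector output before any permutes means the library's backward only ever sees contiguous gradients:

```python
import torch

class ContiguousGrad(torch.autograd.Function):
    """Identity in the forward pass; makes the incoming
    gradient contiguous in the backward pass."""
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad):
        return grad.contiguous()

# usage sketch: wrap the projector output before permuting, e.g.
#   sino = proj_ypx_to_bpxy(ContiguousGrad.apply(radon.forward(vol)))
# demo with a stand-in op in place of radon.forward:
x = torch.randn(3, 4, requires_grad=True)
out = ContiguousGrad.apply(x).permute(1, 0)
out.sum().backward()  # gradient reaching x is now contiguous
```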

Thanks for the report, Matteo

HJ-harry commented 3 years ago

Thank you very much for the quick reply and the fix! :)

Now it works like a charm. Have a great day!

Hyungjin