guopengf / ReconFormer

ReconFormer: Accelerated MRI Reconstruction Using Recurrent Transformer
https://arxiv.org/abs/2201.09376
MIT License

Fix data-consistency module #9

Status: Open. Opened by bilalkabas 3 months ago.

Summary

Fixes #10

This PR fixes the data-consistency module and the 2D Fourier transform functions fft2 and ifft2. The data-consistency module has been updated, and fft2c and ifft2c functions have been added to transforms.py.
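Because fft2c and ifft2c are new, here is a minimal sketch of what centered, orthonormal 2D transforms commonly look like with the current `torch.fft` module. It assumes the (..., H, W, 2) real/imaginary layout that the DC module uses after its `permute` calls, and the common ifftshift, FFT, fftshift convention. It is an illustration only and may differ from the code in this PR.

```python
import torch


def fft2c(data: torch.Tensor) -> torch.Tensor:
    """Centered, orthonormal 2D FFT over (H, W).

    Assumption: `data` has shape (..., H, W, 2), where the last dimension
    holds the real and imaginary parts.
    """
    assert data.shape[-1] == 2, "last dimension must hold (real, imag)"
    x = torch.view_as_complex(data.contiguous())       # (..., H, W) complex
    x = torch.fft.ifftshift(x, dim=(-2, -1))           # move the image center to index 0
    x = torch.fft.fft2(x, dim=(-2, -1), norm="ortho")  # unitary 2D FFT
    x = torch.fft.fftshift(x, dim=(-2, -1))            # move the DC term to the center
    return torch.view_as_real(x)                       # back to (..., H, W, 2)


def ifft2c(data: torch.Tensor) -> torch.Tensor:
    """Centered, orthonormal 2D inverse FFT (the inverse of fft2c)."""
    assert data.shape[-1] == 2, "last dimension must hold (real, imag)"
    x = torch.view_as_complex(data.contiguous())
    x = torch.fft.ifftshift(x, dim=(-2, -1))
    x = torch.fft.ifft2(x, dim=(-2, -1), norm="ortho")
    x = torch.fft.fftshift(x, dim=(-2, -1))
    return torch.view_as_real(x)


if __name__ == "__main__":
    img = torch.randn(1, 320, 320, 2)
    k = fft2c(img)
    print(k.shape)  # torch.Size([1, 320, 320, 2])
    print(torch.allclose(ifft2c(k), img, atol=1e-5))
```

With `norm="ortho"` the pair is unitary, so k-space and image-space energies match and no extra scaling is needed when switching between them inside the DC step.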

Problem definition

The data-consistency module linked below fails with a TypeError when it is called:

https://github.com/guopengf/ReconFormer/blob/e2e0d5e6e58e04ad1c77a1151e63cf56bec21fb1/models/Recurrent_Transformer.py#L13-L55
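For context, the rule a k-space data-consistency layer typically applies is simple: keep the network's k-space estimate `k` where the sampling mask is 0, and restore the measured samples `k0` where it is 1. The sketch below assumes a 0/1 mask that broadcasts against k-space tensors in the (B, H, W, 2) layout. The optional `lam` weight for noisy measurements is a hypothetical parameter name, and this is not a copy of the linked module.

```python
from typing import Optional

import torch


def data_consistency(k: torch.Tensor, k0: torch.Tensor, mask: torch.Tensor,
                     lam: Optional[float] = None) -> torch.Tensor:
    """Blend predicted k-space `k` with measured k-space `k0`.

    Assumptions: `mask` is 0/1 and broadcastable to `k` (e.g. (B, H, W, 1)
    against (B, H, W, 2)); `lam` is a hypothetical weight on the measured
    data, where a larger value trusts `k0` more.
    """
    if lam is not None:
        # Noisy case: weighted average of prediction and measurement at sampled points
        return (1 - mask) * k + mask * (k + lam * k0) / (1 + lam)
    # Noiseless case: sampled points are replaced exactly by the measurements
    return (1 - mask) * k + mask * k0
```

With a 0/1 mask, sampled entries come from `k0` and the rest from `k`. A real-valued random mask, as used in the reproduction below, does not give a meaningful blend but is enough to trigger the crash.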

The failure comes from errors in the fft2 and ifft2 functions in transforms.py:

https://github.com/guopengf/ReconFormer/blob/e2e0d5e6e58e04ad1c77a1151e63cf56bec21fb1/data/transforms.py#L73-L107
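For reference, the legacy `torch.fft(input, signal_ndim, normalized=False)` function (since replaced by the `torch.fft` module) took the number of transformed dimensions as its second argument, and `normalized=True` divided the result by sqrt(N) so that the transform is unitary. In the `torch.fft` module that scaling is `norm="ortho"`. Below is a sketch of non-centered fft2/ifft2 replacements under that mapping; the `normalized` parameter and its default of True are assumptions about how the helpers are called.

```python
import torch


def fft2(data: torch.Tensor, normalized: bool = True) -> torch.Tensor:
    """Non-centered 2D FFT over (H, W) for data shaped (..., H, W, 2).

    normalized=True mimics the legacy flag via norm="ortho";
    normalized=False uses the unscaled forward transform (norm="backward").
    The default of True is an assumption.
    """
    assert data.shape[-1] == 2, "last dimension must hold (real, imag)"
    norm = "ortho" if normalized else "backward"
    x = torch.view_as_complex(data.contiguous())
    return torch.view_as_real(torch.fft.fft2(x, dim=(-2, -1), norm=norm))


def ifft2(data: torch.Tensor, normalized: bool = True) -> torch.Tensor:
    """Inverse of fft2 with the same `normalized` convention."""
    assert data.shape[-1] == 2, "last dimension must hold (real, imag)"
    norm = "ortho" if normalized else "backward"
    x = torch.view_as_complex(data.contiguous())
    # norm="backward" divides by N here, so the pair round-trips in both modes
    return torch.view_as_real(torch.fft.ifft2(x, dim=(-2, -1), norm=norm))
```

Keeping the real/imaginary pair in the last dimension means existing callers, such as the `fftshift(data, dim=(-3, -2))` call visible in the traceback, can stay unchanged.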

To Reproduce


import torch
from backbones.reconformer.reconformer import DataConsistencyInKspace

resolution = 320
device = 'cuda:0'

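# Random inputs are enough to trigger the error; the 2 channels hold (real, imag)
x = torch.randn((1, 2, resolution, resolution)).to(device)
k0 = torch.randn((1, 2, resolution, resolution)).to(device)
# A real sampling mask would be binary (0/1); random values suffice here
mask = torch.randn((1, 1, resolution, resolution)).to(device)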

dc = DataConsistencyInKspace()
out = dc(x, k0, mask)

print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")

Running the snippet fails inside the data-consistency module (traceback abridged):

     46 k0 = k0.permute(0, 2, 3, 1)
     47 mask = mask.permute(0, 2, 3, 1)
...
--> 122 data = torch.fft.fft(data, 2, normalized=normalized)
    123 data = fftshift(data, dim=(-3, -2))
    124 return data

TypeError: fft_fft() got an unexpected keyword argument 'normalized'
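The failing line also hides a second, silent problem. In the `torch.fft` module, `torch.fft.fft(input, n=None, dim=-1, norm=None)` is a 1D transform, so the positional `2` is read as the signal length `n`, not as the number of dimensions. The short check below (random data, 8x8 instead of 320x320 for speed) suggests that removing the `normalized` keyword alone would make the call run, but it would compute a length-2 FFT over the real/imaginary axis instead of a 2D FFT over the image.

```python
import torch

# (B, H, W, 2) layout, as produced by the permute calls in the traceback
data = torch.randn(1, 8, 8, 2)

# The keyword from the legacy API is rejected by the torch.fft module
try:
    torch.fft.fft(data, 2, normalized=True)
except TypeError as err:
    print("TypeError:", err)

# Without the keyword the call runs, but `2` becomes n: this is a 1D FFT
# along the last dimension (the real/imag pair), not a 2D FFT over (H, W)
wrong = torch.fft.fft(data, 2)
print(wrong.shape, wrong.dtype)  # torch.Size([1, 8, 8, 2]) torch.complex64

# The intended operation: a 2D FFT over the spatial dimensions of the complex image
right = torch.fft.fft2(torch.view_as_complex(data), dim=(-2, -1), norm="ortho")
print(right.shape)  # torch.Size([1, 8, 8])
```

For a pair (a, b), the length-2 FFT returns (a + b, a - b), which explains why this kind of mistake can produce plausible-looking shapes with meaningless values.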