facebookresearch / fastMRI

A large-scale dataset of both raw MRI measurements and clinical MRI images.
https://fastmri.org
MIT License

Varying input size leads to NaNs during training #139

Closed tobiasvitt closed 3 years ago

tobiasvitt commented 3 years ago

Hey, I'm using the banding_removal codebase and running the scripts/pretrain.py script with its provided configuration. As the training dataset I use the fastMRI multicoil knee dataset.

I currently have two scenarios leading to NaNs in some ConvBlocks and in the losses; both are caused by the same underlying problem.

First, running with batch_size > 1 in the non-distributed setup: after a few steps in the first epoch the losses become NaN, because the input shape changes in that step (e.g. from [15, 640, 372, 2] to [15, 640, 322, 2]). I assume the different input sizes (particularly the smaller ones) let some NaN values sneak in and propagate, leading to NaN losses and eventually everything becoming NaN.

Second, in distributed mode spawn_dist.run(args) with batch_size = 1: RuntimeError: stack expects each tensor to be equal size, but got [15, 640, 372, 2] at entry 0 and [15, 322, 640, 2] at entry 1. The issue here is that the differently sized inputs can't be stacked.

So the underlying issue is that the training data has varying input sizes.

I thought that the resize_type argument in args.py (default 'crop') would already handle this issue, but as far as I can see it only affects the display_data.

How can I fix this issue? I thought that I could somehow pad each input in the KSpaceDataTransform, but that didn't work; somehow __call__ is not executed, so I can't apply a transformation.
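For illustration, a pad-and-stack collate function along these lines is roughly what I had in mind (pad_collate is my own sketch for the raw k-space tensors, not something from the codebase, and I'm not sure zero-padding k-space is even valid here):

import torch
import torch.nn.functional as F

def pad_collate(samples):
    # samples: list of [coils, H, W, 2] k-space tensors with varying H and W
    max_h = max(t.shape[1] for t in samples)
    max_w = max(t.shape[2] for t in samples)
    padded = []
    for t in samples:
        dh, dw = max_h - t.shape[1], max_w - t.shape[2]
        # F.pad pads dimensions from the last one backwards: (2-dim, 2-dim, W, W, H, H)
        padded.append(F.pad(t, (0, 0, 0, dw, 0, dh)))
    return torch.stack(padded, dim=0)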

adefazio commented 3 years ago

For our development we used large models for which it wasn't useful to use a batch size larger than 1 per GPU; in fact, it would result in out-of-memory issues. The codebase is tailored to this situation and won't work with larger batch sizes. I would recommend you just use batch size 1; it should give good performance and GPU utilization if you are using large models. It's possible to adapt the code to larger batches, for instance by mapping to images, padding, then mapping back, but you are likely to run into unexpected issues. The KSpace transform is the place to do that, though there are probably other places in the codebase that will break as well. I'm surprised the __call__ method is not being invoked; I would expect it to be.
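Very roughly, the map-to-images-then-pad idea looks something like this; it is only a sketch using the newer torch.fft module (not code from this repository), assuming the [coils, H, W, 2] real/imaginary k-space layout from your shapes:

import torch
import torch.nn.functional as F

def pad_kspace_via_image(kspace, target_w):
    # kspace: [coils, H, W, 2] real/imag k-space, as in the shapes reported above
    x = torch.view_as_complex(kspace.contiguous())          # [coils, H, W] complex
    img = torch.view_as_real(torch.fft.ifftn(x, dim=(-2, -1), norm="ortho"))
    left = (target_w - img.shape[-2]) // 2                  # zero-pad the width in image space
    right = target_w - img.shape[-2] - left
    img = F.pad(img, (0, 0, left, right))                   # pads dims from the last one backwards
    x = torch.view_as_complex(img.contiguous())
    return torch.view_as_real(torch.fft.fftn(x, dim=(-2, -1), norm="ortho"))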

I can help debug the issue you see in the batch size 1 case; can you provide a full stack trace? It's possible something broke (a PyTorch change?) since I finished the paper.

tobiasvitt commented 3 years ago

Hey Aaron, thank you for your response and the offer to help out. Below you can find more information about my issue.

My requirements.txt (Python 3.8) looks like this:

Pillow == 8.2.0
apex == 0.9.10.dev0
clearml == 1.0.2
h5py == 3.2.1
ismrmrd == 1.7.3
numpy == 1.20.3
pytest == 6.2.4
runstats == 1.8.0
scikit_image == 0.18.0
scipy == 1.6.3
tensorboardX == 2.2
torch == 1.7.1
torchvision == 0.8.2

I used torch 1.7.1, since this codebase still uses the old torch.fft(…) function, which is deprecated in 1.7 and removed in later versions. The only change I made was replacing skimage.measure.compare_psnr with skimage.metrics.peak_signal_noise_ratio, which is the renamed version and should behave the same.
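Concretely, the swap is just the import (a drop-in replacement as far as I can tell):

# Old (scikit-image < 0.18):
# from skimage.measure import compare_psnr
# New:
from skimage.metrics import peak_signal_noise_ratio as compare_psnr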

I run scripts/pretrain.py with its default configuration on the full fastMRI knee dataset. I use batch size 1 as defined in the config in pretrain.py and have two Nvidia Quadro RTX 6000/8000 GPUs attached. I use line 66 of pretrain.py, spawn_dist.run(config), instead of line 67, run.run(config).

BATCH_SIZE = 1

When starting training, in Epoch 0 after 40-50 samples I get the following behaviour:

{'train_loss': tensor(0.1727, device='cuda:1', grad_fn=<AddBackward0>), 'ssim_loss': tensor(0.1727, device='cuda:1', grad_fn=<RsubBackward1>), 'l1_loss': tensor(9.5153e-06, device='cuda:1', grad_fn=<L1LossBackward>)}
torch.Size([1, 15, 372, 640, 2])
{'train_loss': tensor(0.4003, device='cuda:1', grad_fn=<AddBackward0>), 'ssim_loss': tensor(0.4003, device='cuda:1', grad_fn=<RsubBackward1>), 'l1_loss': tensor(6.7909e-06, device='cuda:1', grad_fn=<L1LossBackward>)}
torch.Size([1, 15, 640, 372, 2])
{'train_loss': tensor(0.1621, device='cuda:1', grad_fn=<AddBackward0>), 'ssim_loss': tensor(0.1621, device='cuda:1', grad_fn=<RsubBackward1>), 'l1_loss': tensor(5.0518e-06, device='cuda:1', grad_fn=<L1LossBackward>)}
2021-05-21 21:40:48
torch.Size([1, 15, 372, 640, 2])
{'train_loss': tensor(0.4247, device='cuda:0', grad_fn=<AddBackward0>), 'ssim_loss': tensor(0.4247, device='cuda:0', grad_fn=<RsubBackward1>), 'l1_loss': tensor(8.7041e-06, device='cuda:0', grad_fn=<L1LossBackward>)}
torch.Size([1, 15, 640, 372, 2])
{'train_loss': tensor(0.1544, device='cuda:0', grad_fn=<AddBackward0>), 'ssim_loss': tensor(0.1544, device='cuda:0', grad_fn=<RsubBackward1>), 'l1_loss': tensor(1.1122e-05, device='cuda:0', grad_fn=<L1LossBackward>)}
torch.Size([1, 15, 640, 372, 2])
{'train_loss': tensor(0.1082, device='cuda:0', grad_fn=<AddBackward0>), 'ssim_loss': tensor(0.1082, device='cuda:0', grad_fn=<RsubBackward1>), 'l1_loss': tensor(4.7244e-06, device='cuda:0', grad_fn=<L1LossBackward>)}
torch.Size([1, 15, 640, 372, 2])
{'train_loss': tensor(0.2167, device='cuda:0', grad_fn=<AddBackward0>), 'ssim_loss': tensor(0.2167, device='cuda:0', grad_fn=<RsubBackward1>), 'l1_loss': tensor(4.8453e-06, device='cuda:0', grad_fn=<L1LossBackward>)}
torch.Size([1, 15, 372, 640, 2])
{'train_loss': tensor(0.1591, device='cuda:0', grad_fn=<AddBackward0>), 'ssim_loss': tensor(0.1591, device='cuda:0', grad_fn=<RsubBackward1>), 'l1_loss': tensor(8.6388e-06, device='cuda:0', grad_fn=<L1LossBackward>)}
torch.Size([1, 15, 322, 640, 2])
{'train_loss': tensor(nan, device='cuda:0', grad_fn=<AddBackward0>), 'ssim_loss': tensor(nan, device='cuda:0', grad_fn=<RsubBackward1>), 'l1_loss': tensor(nan, device='cuda:0', grad_fn=<L1LossBackward>)}
2021-05-21 21:40:43,880 | NaN or Inf found in input tensor.
2021-05-21 21:40:43,882 | NaN or Inf found in input tensor.
2021-05-21 21:40:43,882 | NaN or Inf found in input tensor.
2021-05-21 21:40:43,882 | NaN or Inf found in input tensor.
2021-05-21 21:40:43,882 | NaN or Inf found in input tensor.
2021-05-21 21:40:43,882 | NaN or Inf found in input tensor.

As you can see, I'm printing the input shape (shown as torch.Size([1, 15, ***, 640, 2])) in the training loop, and I also print the loss_dict there.

I executed the experiment multiple times and always hit the NaN or Inf found in input tensor issue right after the input shape torch.Size([1, 15, ***, 640, 2]) changes.

BATCH_SIZE = 2

Same experiment as described above, just executed with batch_size = 2:

Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/process.py", line 313, in _bootstrap
    self.run()
  File "/usr/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/user/venvs-builds/3.8/repo/fastmri/spawn_dist.py", line 32, in work
    run_task(args)
  File "/home/user/venvs-builds/3.8/repo/fastmri/run.py", line 34, in run
    trainer.train()
  File "/home/user/venvs-builds/3.8/repo/fastmri/base_trainer.py", line 206, in train
    self.run(epoch)
  File "/home/user/venvs-builds/3.8/repo/fastmri/training_loop_mixin.py", line 92, in run
    for batch_idx, batch in enumerate(self.train_loader):
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
    data = self._next_data()
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1085, in _next_data
    return self._process_data(data)
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1111, in _process_data
    data.reraise()
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/_utils.py", line 428, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 73, in default_collate
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 73, in <dictcomp>
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [15, 640, 372, 2] at entry 0 and [15, 322, 640, 2] at entry 1

Both have the same underlying issue: the input size varies.

adefazio commented 3 years ago

For the NaN issue, try adding the torch.autograd.detect_anomaly context manager to the code; it will give you a better idea of where the NaN comes from. You may also have to add some isnan checks to narrow it down further. Does training work when running with 1 GPU (run.run(config) instead)?
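A minimal, standalone example of what I mean (not the trainer code):

import torch

x = torch.tensor([-1.0], requires_grad=True)
with torch.autograd.detect_anomaly():
    y = torch.sqrt(x)      # the forward pass already produces a NaN here
    y.sum().backward()     # anomaly mode raises and prints the forward traceback

A plain assert not torch.isnan(t).any() after suspicious operations works the same way for narrowing down the forward pass.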

tobiasvitt commented 3 years ago

Ok, I added 'nan_detection': True to the config in scripts/pretrain.py, since this sets autograd.set_detect_anomaly(True) in training_loop_mixin.py.

Running with run.run(config) leads to the same error described above; here is the traceback from the anomaly detection:

[W python_anomaly_mode.cpp:104] Warning: Error detected in DivBackward0. Traceback of forward call that caused the error:
  File "scripts/pretrain.py", line 72, in <module>
    run.run(args)  # Single GPU training (for debugging)
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/run.py", line 34, in run
    trainer.train()
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/base_trainer.py", line 206, in train
    self.run(epoch)
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/training_loop_mixin.py", line 150, in run
    loss, loss_dict = self.optimizer.step(closure=closure)
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/optim/adam.py", line 66, in step
    loss = closure()
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/training_loop_mixin.py", line 149, in <lambda>
    closure = lambda: batch_closure(batch)
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/training_loop_mixin.py", line 114, in batch_closure
    result = self.training_loss(subbatch)
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/ssim_loss_mixin.py", line 59, in training_loss
    ssim_loss = 1 - self.ssim(output_, target_, data_range=max_value)
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/ssim_loss_mixin.py", line 40, in forward
    S = (A1 * A2) / D
 (function _print_stack)
2021-05-21 23:58:52,931 | Uncaught exception
Traceback (most recent call last):
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/run.py", line 34, in run
    trainer.train()
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/base_trainer.py", line 206, in train
    self.run(epoch)
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/training_loop_mixin.py", line 150, in run
    loss, loss_dict = self.optimizer.step(closure=closure)
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/optim/adam.py", line 66, in step
    loss = closure()
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/training_loop_mixin.py", line 149, in <lambda>
    closure = lambda: batch_closure(batch)
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/training_loop_mixin.py", line 143, in batch_closure
    self.backwards(loss)
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/base_trainer.py", line 222, in backwards
    loss.backward()
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/autograd/__init__.py", line 130, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Function 'DivBackward0' returned nan values in its 0th output.
Traceback (most recent call last):
  File "scripts/pretrain.py", line 72, in <module>
    run.run(args)  # Single GPU training (for debugging)
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/run.py", line 34, in run
    trainer.train()
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/base_trainer.py", line 206, in train
    self.run(epoch)
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/training_loop_mixin.py", line 150, in run
    loss, loss_dict = self.optimizer.step(closure=closure)
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/optim/adam.py", line 66, in step
    loss = closure()
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/training_loop_mixin.py", line 149, in <lambda>
    closure = lambda: batch_closure(batch)
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/training_loop_mixin.py", line 143, in batch_closure
    self.backwards(loss)
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/base_trainer.py", line 222, in backwards
    loss.backward()
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/autograd/__init__.py", line 130, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Function 'DivBackward0' returned nan values in its 0th output.

The result with spawn_dist.run(config):

Traceback (most recent call last):
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/run.py", line 34, in run
    trainer.train()
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/base_trainer.py", line 206, in train
    self.run(epoch)
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/training_loop_mixin.py", line 150, in run
    loss, loss_dict = self.optimizer.step(closure=closure)
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/optim/adam.py", line 66, in step
    loss = closure()
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/training_loop_mixin.py", line 149, in <lambda>
    closure = lambda: batch_closure(batch)
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/training_loop_mixin.py", line 143, in batch_closure
    self.backwards(loss)
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/base_trainer.py", line 222, in backwards
    loss.backward()
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/autograd/__init__.py", line 130, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Function 'DivBackward0' returned nan values in its 0th output.
/home/user/venvs-builds/3.8/task_repository/repo/fastmri/data/transforms.py:146: UserWarning: The function torch.ifft is deprecated and will be removed in a future PyTorch release. Use the new torch.fft module functions, instead, by importing torch.fft and calling torch.fft.ifft or torch.fft.ifftn. (Triggered internally at  /pytorch/aten/src/ATen/native/SpectralOps.cpp:578.)
  data = torch.ifft(data, 2, normalized=True)
/home/user/venvs-builds/3.8/task_repository/repo/fastmri/data/transforms.py:127: UserWarning: The function torch.fft is deprecated and will be removed in PyTorch 1.8. Use the new torch.fft module functions, instead, by importing torch.fft and calling torch.fft.fft or torch.fft.fftn. (Triggered internally at  /pytorch/aten/src/ATen/native/SpectralOps.cpp:567.)
  data = torch.fft(data, 2, normalized=True)
/home/user/venvs-builds/3.8/task_repository/repo/fastmri/ssim_loss_mixin.py:60: UserWarning: This overload of add is deprecated:
    add(Number alpha, Tensor other)
Consider using one of the following signatures instead:
    add(Tensor other, *, Number alpha) (Triggered internally at  /pytorch/torch/csrc/utils/python_arg_parser.cpp:882.)
  loss = ssim_loss.add(self.ssim_l1_coefficient, l1_loss)
Exception in thread Thread-1:
Traceback (most recent call last):
  File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
Process Process-2:
    self.run()
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/tensorboardX/event_file_writer.py", line 202, in run
Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/process.py", line 313, in _bootstrap
    self.run()
  File "/usr/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/spawn_dist.py", line 32, in work
    run_task(args)
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/run.py", line 34, in run
    trainer.train()
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/base_trainer.py", line 206, in train
    self.run(epoch)
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/training_loop_mixin.py", line 150, in run
    loss, loss_dict = self.optimizer.step(closure=closure)
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/optim/adam.py", line 66, in step
    loss = closure()
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/training_loop_mixin.py", line 149, in <lambda>
    closure = lambda: batch_closure(batch)
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/training_loop_mixin.py", line 143, in batch_closure
    self.backwards(loss)
  File "/home/user/venvs-builds/3.8/task_repository/repo/fastmri/base_trainer.py", line 222, in backwards
    loss.backward()
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/user/venvs-builds/3.8/lib/python3.8/site-packages/torch/autograd/__init__.py", line 130, in backward
    Variable._execution_engine.run_backward(
RuntimeError: Function 'DivBackward0' returned nan values in its 0th output.

Running the experiment on a single GPU or distributed results in the same issue. Is it actually intended that the inputs have different shapes?
My batch["input"] can be anything like torch.Size([1, 15, X, 640, 2]) where X is in [322, 338, 368, 370, 372, 388, ...]. I still believe this is the underlying issue.

adefazio commented 3 years ago

Yes, the differing input sizes are intentional. It was a design decision we made: rather than cropping or padding, which can be problematic when working with mixed Fourier/spatial domains, we decided to just have the model directly support differing input sizes.

adefazio commented 3 years ago

It's not clear to me what's causing the problem. It's possible that it may work with a smaller learning rate: since you are using fewer GPUs than we did, a smaller LR may be required to stabilize training. You could also try a larger momentum value, which tends to result in greater stability (momentum 0.98 maybe?).

Increasing ssim_l1_coefficient might help as well. Unfortunately this sort of adversarial training is extremely fickle and diverges easily.
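Something along these lines in the scripts/pretrain.py config; I'm writing the key names from memory, so check them against args.py:

# Hedged sketch only: the exact key names and sensible values may differ.
overrides = {
    "lr": 1e-4,                   # smaller learning rate when using fewer GPUs
    "momentum": 0.98,             # more momentum tends to stabilize training
    "ssim_l1_coefficient": 1e-3,  # weight the L1 term more heavily
}
config.update(overrides)          # `config` is the dict built in scripts/pretrain.py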

tobiasvitt commented 3 years ago

I will try a different momentum, ssim_l1_coefficient, and lr once I am able to train the model.

I'm quite certain that the issue described above is not caused by model instability, the GPU setup, or hyperparameters.

After doing some more debugging, I could trace the issue down to the MaskCenter function. Training runs well for a few batches with input size [15, 372, 640, 2] and fails with the first batch of size [15, 322, 640, 2].

The NaNs are produced inside the Norm(nn.Module).forward() function because the mean and std are zero in x = (x - mean) / std. The norm(x, model, norm_type, norm_mean, norm_std) function was called with type = 'group' and mean and std set to true. The mean and std were zero because the input x to the norm function consisted only of zeros.
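The failure itself is easy to reproduce in isolation (toy example, not the repository's Norm module):

import torch

x = torch.zeros(4)             # the all-zero input that reaches Norm.forward()
mean, std = x.mean(), x.std()  # both are 0.
print((x - mean) / std)        # tensor([nan, nan, nan, nan]) -- 0/0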

After finding the reason for the NaN production, I followed the x = zeros issue further. The functions called before the norm, Pad(nn.Module).forward() and Fm2Batch(nn.Module).forward(), also receive x = zeros as input. Those zeros are produced in the MaskCenter(nn.Module).forward() function. As far as I understand, MaskCenter(nn.Module).forward() is responsible for the central band of 16 lowest-frequency k-space lines. Looking at the values of each variable in MaskCenter(nn.Module).forward() and the internally called mask_center(x, num_lf) suggests that the mask consists purely of zeros because the chosen section x[:,:,:,312:328] only contains zeros. Using x.detach().unique() tells us that x in general has many non-zero values, yet the chosen section contains only zeros, as seen below.

Below you can see both functions with the respective variable values at the time of execution; the observed values are shown as {..} comments after the respective code lines.

Execution resulting in zeros:

class MaskCenter(nn.Module):
    def forward(self, x, input):
        s = x.size(2)                # {s = 1}
        mask = torch.zeros_like(x)   # {mask.shape = [1, 15, 1, 322, 640, 2]}
        for j in range(s):
            lf = input['num_lf'][j]  # {lf = 16}
            mask[:, :, j, ...] = T.mask_center(x[:, :, j, ...], lf)
        return mask                  # {mask.detach().unique() = [0.]}

def mask_center(x, num_lf):          # {num_lf = 16}
    b, c, h, w, two = x.shape        # {b = 1, c = 15, h = 322, w = 640, two = 2}
    mask = torch.zeros_like(x)       # {mask.shape = [1, 15, 322, 640, 2]}
    pad = (w - num_lf + 1) // 2      # {pad = 312, num_lf = 16}
    mask[:, :, :, pad:pad + num_lf] = x[:, :, :, pad:pad + num_lf]  # {x[:, :, :, pad:pad + num_lf].detach().unique() = [0.]}
    return mask                      # {mask.detach().unique() = [0.]}

Execution resulting in a proper mask:

class MaskCenter(nn.Module):
    def forward(self, x, input):
        s = x.size(2)                # {s = 1}
        mask = torch.zeros_like(x)   # {mask.shape = [1, 15, 1, 372, 640, 2]}
        for j in range(s):
            lf = input['num_lf'][j]  # {lf = 16}
            mask[:, :, j, ...] = T.mask_center(x[:, :, j, ...], lf)
        return mask                  # {mask.detach().unique() = [-0.0051, -0.0044, -0.0033,  ...,  0.0027,  0.0039,  0.0043]}

def mask_center(x, num_lf):          # {num_lf = 16}
    b, c, h, w, two = x.shape        # {b = 1, c = 15, h = 372, w = 640, two = 2}
    mask = torch.zeros_like(x)       # {mask.shape = [1, 15, 372, 640, 2]}
    pad = (w - num_lf + 1) // 2      # {pad = 312, num_lf = 16}
    mask[:, :, :, pad:pad + num_lf] = x[:, :, :, pad:pad + num_lf]  # {x[:, :, :, pad:pad + num_lf].detach().unique() = tensor([-0.0051, -0.0044, -0.0033,  ...,  0.0027,  0.0039,  0.0043])}
    return mask                      # {mask.detach().unique() = tensor([-0.0051, -0.0044, -0.0033,  ...,  0.0027,  0.0039,  0.0043])}

I don't really understand why this happens... @adefazio, can you provide the package versions you used? I generally get a few warnings that the functions used are deprecated, and I just want to make sure that version inconsistencies can't be the reason for my issue.

adefazio commented 3 years ago

This is very bizarre! Thanks for tracking down this issue further. If that's the only issue occurring, you may be able to fix it by adding an epsilon like 1e-5 to the std variable; that's what the official group norm implementation in PyTorch does. As to why this happens, I don't know. I wish I could provide the exact set of packages used for the paper, but unfortunately I didn't keep track of them when I ran the experiments. Have you tried running the main fastMRI implementation in this repository? It uses almost the same varnet model with the same normalization layer, I believe.
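Roughly like this (just a sketch of the idea, not the repository's exact Norm code):

import torch

def normalize(x, eps=1e-5):
    # group-norm-style normalization with an epsilon in the denominator
    mean = x.mean(dim=(-2, -1), keepdim=True)
    std = x.std(dim=(-2, -1), keepdim=True)
    return (x - mean) / (std + eps)   # all-zero inputs now give 0 instead of NaN

print(normalize(torch.zeros(1, 15, 322, 640)).unique())  # tensor([0.])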

tobiasvitt commented 3 years ago

Thank you for your response. For now I have stopped debugging this issue, but I will try adding the epsilon. Over the last few days I trained the E2E VarNet from the main repository on both the single-coil and multi-coil datasets with decent reconstruction results, so I assume the dataset itself can't be the issue. If I find the problem in the banding-removal codebase, I will comment here.

mmuckley commented 3 years ago

Closing due to inactivity. Reopen if you'd like to continue the thread.