deepinv / deepinv

PyTorch library for solving imaging inverse problems using deep learning
https://deepinv.github.io
BSD 3-Clause "New" or "Revised" License

Batch not working for ADMM deblurring #219

Closed. Tmodrzyk closed this issue 4 months ago.

Tmodrzyk commented 4 months ago

Hello, I've been trying to set up a benchmark of all your methods for deblurring.

However, I noticed that with a batch size > 1, some iterators raise errors. So far, PGD and FISTA handle batch size > 1, but ADMM and DRS do not. Both work with batch size = 1, though.

Is this the expected behavior? From the documentation, I assumed any batch size would work with any method.

I have not tested other physics, so the issue may be specific to the blur operator. Here is a minimal example of the issue:

import deepinv as dinv
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from deepinv.utils.demo import load_dataset
from deepinv.optim.data_fidelity import L2
from deepinv.optim.optimizers import optim_builder
from deepinv.training.testing import test

BASE_DIR = Path(".")
ORIGINAL_DATA_DIR = BASE_DIR / "datasets"
DATA_DIR = BASE_DIR / "measurements"
torch.manual_seed(0)
device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
img_size = 256 if torch.cuda.is_available() else 32
num_workers = 4 if torch.cuda.is_available() else 0

# Load dataset
val_transform = transforms.Compose([transforms.CenterCrop(img_size), transforms.ToTensor()])
dataset = load_dataset("set3c", ORIGINAL_DATA_DIR, transform=val_transform)

# Define physics model
physics = dinv.physics.Blur(
    filter=dinv.physics.blur.gaussian_blur(sigma=3.0, angle=0.0),
    n_channels=3, 
    device=device, 
    padding="circular",
    noise_model=dinv.physics.GaussianNoise(sigma=0.02)
)

# Generate measurement dataset
measurement_dir = DATA_DIR / "set3c" / "deblur"
dinv_dataset_path = dinv.datasets.generate_dataset(
    train_dataset=None, 
    test_dataset=dataset, 
    physics=physics,
    device=device, 
    save_dir=measurement_dir, 
    num_workers=num_workers
)

# Define optimizer model
modelADMM_TV = optim_builder(
    iteration="ADMM", 
    prior=dinv.optim.prior.TVPrior(n_it_max=20),
    data_fidelity=L2(), 
    early_stop=True, 
    max_iter=100, 
    verbose=True,
    params_algo={"stepsize": 1.0, "lambda": 1e-2}
)

# Create DataLoader and run the test
dataloader = DataLoader(dinv.datasets.HDF5Dataset(path=dinv_dataset_path, train=False), 
                        batch_size=3, num_workers=num_workers, shuffle=False)

test(
    model=modelADMM_TV, 
    test_dataloader=dataloader, 
    physics=physics,
    metrics=[dinv.loss.PSNR()], 
    device=device, 
    plot_images=True,
    wandb_vis=False, 
    plot_only_first_batch=False
)

Which produces the following output:

Selected GPU 0 with 7686 MB free memory 
Dataset has been saved in measurements/set3c/deblur
Test :   0%|                                                                                                                    | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/modrzyk/code/benchmark/minimal_example.py", line 58, in <module>
    test(
  File "/home/modrzyk/miniconda3/envs/deepinv/lib/python3.12/site-packages/deepinv/training/testing.py", line 144, in test
    x_net = model(y, physics_cur)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/home/modrzyk/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/modrzyk/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/modrzyk/miniconda3/envs/deepinv/lib/python3.12/site-packages/deepinv/optim/optimizers.py", line 478, in forward
    X, metrics = self.fixed_point(
                 ^^^^^^^^^^^^^^^^^
  File "/home/modrzyk/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/modrzyk/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/modrzyk/miniconda3/envs/deepinv/lib/python3.12/site-packages/deepinv/optim/fixed_point.py", line 241, in forward
    X = self.iterator(X_prev, cur_data_fidelity, cur_prior, cur_params, *args)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/modrzyk/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/modrzyk/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/modrzyk/miniconda3/envs/deepinv/lib/python3.12/site-packages/deepinv/optim/optim_iterators/admm.py", line 57, in forward
    u = self.f_step(x, z, cur_data_fidelity, cur_params, y, physics)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/modrzyk/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/modrzyk/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/modrzyk/miniconda3/envs/deepinv/lib/python3.12/site-packages/deepinv/optim/optim_iterators/admm.py", line 91, in forward
    return cur_data_fidelity.prox(p, y, physics, gamma=cur_params["stepsize"])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/modrzyk/miniconda3/envs/deepinv/lib/python3.12/site-packages/deepinv/optim/data_fidelity.py", line 323, in prox
    return physics.prox_l2(x, y, self.norm * gamma)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/modrzyk/miniconda3/envs/deepinv/lib/python3.12/site-packages/deepinv/physics/forward.py", line 551, in prox_l2
    x = conjugate_gradient(H, b, self.max_iter, self.tol)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/modrzyk/miniconda3/envs/deepinv/lib/python3.12/site-packages/deepinv/optim/utils.py", line 76, in conjugate_gradient
    x = x + p * alpha
            ~~^~~~~~~
RuntimeError: The size of tensor a (256) must match the size of tensor b (3) at non-singleton dimension 3

I have the most recent version of deepinv installed.
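The last frame of the traceback fails at `x = x + p * alpha` inside `conjugate_gradient`. The message is consistent with a per-sample quantity of shape (B,) being broadcast against a (B, C, H, W) image batch: PyTorch aligns trailing dimensions, so the length-3 batch axis is compared with W = 256. This is an inference from the error message, not from deepinv's source. The standalone sketch below reproduces the same kind of failure with plain tensors (the names `p`, `r`, `alpha` only mirror the traceback) and shows the usual reshape to (B, 1, 1, 1):

```python
import torch

# Shapes from the traceback: a batch of 3 RGB images of size 256 x 256.
B, C, H, W = 3, 3, 256, 256
p = torch.randn(B, C, H, W)
r = torch.randn(B, C, H, W)

# A per-sample inner product reduced over (C, H, W) has shape (B,).
rs = (r * r).sum(dim=(1, 2, 3))
pp = (p * p).sum(dim=(1, 2, 3))  # stand-in for <p, H p>
alpha = rs / pp                  # shape (3,)

# Broadcasting aligns trailing dimensions: alpha's length-3 axis is compared
# with W = 256, so this multiplication raises a RuntimeError.
try:
    p * alpha
    broadcast_failed = False
except RuntimeError as err:
    broadcast_failed = True
    print(err)

# With a batch of one, alpha has shape (1,) and broadcasts silently,
# which is why batch size 1 does not expose the problem.
single = p[:1] * alpha[:1]

# Reshaping to (B, 1, 1, 1) applies one step size per image.
step = p * alpha.view(B, 1, 1, 1)
```

The batch-size-1 case is worth noting: a shape (1,) tensor broadcasts against any trailing dimension, so the same code gives correct results there and only breaks once B > 1.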

Tmodrzyk commented 4 months ago

Update: using BlurFFT instead of Blur seems to solve the issue.

tachella commented 4 months ago

Hi Thibault,

Many thanks for pointing this out. I think the problem is related to the conjugate_gradient solver, which is only called by optimization schemes that require the proximal operator of the data-fidelity term. That operator is not available in closed form for the Blur operator, but it is for BlurFFT.
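To illustrate why the FFT-based operator avoids the iterative solver: with circular boundary conditions, the blur is diagonalized by the 2-D FFT, so the L2 proximal step can be evaluated frequency by frequency and acts on every image of the batch independently. The sketch below uses plain `torch.fft`, not deepinv's BlurFFT; the helper names and the convention prox(z) = argmin_x gamma/2 ||Ax - y||^2 + 1/2 ||x - z||^2 are assumptions for illustration, and deepinv may scale gamma differently.

```python
import torch
import torch.fft as fft

def kernel_to_otf(kernel, shape):
    """Zero-pad a (kh, kw) kernel to `shape`, centre it at (0, 0), return its 2-D FFT."""
    kh, kw = kernel.shape
    padded = torch.zeros(shape, dtype=kernel.dtype)
    padded[:kh, :kw] = kernel
    padded = torch.roll(padded, shifts=(-(kh // 2), -(kw // 2)), dims=(0, 1))
    return fft.fft2(padded)

def blur(x, otf):
    """A x: circular convolution of every (H, W) plane of a (B, C, H, W) batch."""
    return fft.ifft2(fft.fft2(x) * otf).real

def blur_adjoint(y, otf):
    """A^T y: for a real kernel, the adjoint uses the conjugate transfer function."""
    return fft.ifft2(fft.fft2(y) * otf.conj()).real

def prox_l2_fft(z, y, otf, gamma):
    """Solve (gamma A^T A + I) x = gamma A^T y + z, one division per frequency."""
    num = gamma * otf.conj() * fft.fft2(y) + fft.fft2(z)
    den = gamma * otf.abs() ** 2 + 1.0
    return fft.ifft2(num / den).real

# Example: a batch of 3 RGB images and a 5x5 Gaussian kernel.
torch.manual_seed(0)
B, C, H, W = 3, 3, 32, 32
ax = torch.arange(-2, 3, dtype=torch.float64)
g = torch.exp(-ax**2 / 2.0)
kernel = torch.outer(g, g)
kernel = kernel / kernel.sum()
otf = kernel_to_otf(kernel, (H, W))

x_true = torch.rand(B, C, H, W, dtype=torch.float64)
y = blur(x_true, otf) + 0.02 * torch.randn(B, C, H, W, dtype=torch.float64)
z = torch.rand(B, C, H, W, dtype=torch.float64)
gamma = 1.0
x_hat = prox_l2_fft(z, y, otf, gamma)

# First-order optimality: gamma * A^T (A x - y) + (x - z) should vanish.
grad = gamma * blur_adjoint(blur(x_hat, otf) - y, otf) + (x_hat - z)
print(x_hat.shape, grad.abs().max().item())
```

Because the transfer function has shape (H, W), it broadcasts over the batch and channel axes, so no per-sample scalars are involved at all.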

Will investigate further why conjugate_gradient is not working as expected here.
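For operators without a closed-form prox, the traceback shows that `prox_l2` falls back on `conjugate_gradient(H, b, self.max_iter, self.tol)`, i.e. it solves a linear system H x = b. Below is a minimal batch-aware CG sketch, not the code merged in the fix: every inner product is reduced over all non-batch dimensions with `keepdim=True`, so `alpha` and `beta` have shape (B, 1, 1, 1) and each sample is solved independently. The operator `normal_op` and the matrix `M` are hypothetical stand-ins for a normal operator of the form I + gamma A^T A (an assumption about what H looks like).

```python
import torch

def batched_dot(a, b):
    """Per-sample inner product of two (B, ...) tensors, kept as shape (B, 1, ..., 1)."""
    dims = tuple(range(1, a.dim()))
    return (a * b).sum(dim=dims, keepdim=True)

def conjugate_gradient(A, b, max_iter=100, tol=1e-10):
    """Solve A(x) = b for every sample of the batch; A must be SPD and act sample-wise."""
    x = torch.zeros_like(b)
    r = b.clone()
    p = r.clone()
    rs_old = batched_dot(r, r)
    for _ in range(max_iter):
        Ap = A(p)
        denom = batched_dot(p, Ap)
        # One step size per sample; the guard freezes samples whose direction is zero.
        alpha = torch.where(denom > 0, rs_old / denom, torch.zeros_like(denom))
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = batched_dot(r, r)
        if torch.all(rs_new.sqrt() < tol):
            break
        beta = torch.where(rs_old > 0, rs_new / rs_old, torch.zeros_like(rs_old))
        p = r + beta * p
        rs_old = rs_new
    return x

# Hypothetical stand-in for H = I + gamma * A^T A, with A a dense matrix M.
torch.manual_seed(0)
B, C, H, W = 3, 1, 4, 4
n = C * H * W
M = torch.randn(n, n, dtype=torch.float64)
gamma = 0.5

def normal_op(x):
    v = x.reshape(x.shape[0], -1)    # (B, n): one row per sample
    out = v + gamma * (v @ M.T) @ M  # row i is (I + gamma M^T M) v_i
    return out.reshape(x.shape)

b = torch.randn(B, C, H, W, dtype=torch.float64)
x = conjugate_gradient(normal_op, b)

# Reference solution: dense solve with all samples as right-hand-side columns.
Q = torch.eye(n, dtype=torch.float64) + gamma * M.T @ M
x_ref = torch.linalg.solve(Q, b.reshape(B, -1).T).T.reshape(B, C, H, W)
print(torch.allclose(x, x_ref, atol=1e-8))
```

One consequence of this design: convergence is checked on all samples at once, so the loop runs until the slowest sample converges, and the `torch.where` guards keep already-converged samples from producing NaNs in the meantime.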

tachella commented 4 months ago

The problem has been fixed in #221

Tmodrzyk commented 3 months ago

Many thanks for the quick response.