Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

Cannot run V-Net on medical decathlon data #7852

Closed: linnabraham closed this 3 months ago

linnabraham commented 5 months ago

Describe the bug
PyTorch complains of a size mismatch when using V-Net with Medical Decathlon data.

To Reproduce

import monai
from monai.apps import DecathlonDataset
from monai.transforms import LoadImaged, EnsureChannelFirstd, ScaleIntensityd, ToTensord, Compose
from monai.networks.nets import VNet
from monai.losses.dice import DiceLoss
import torch

def train_one_epoch(train_loader, loss_fn, optimizer, epoch):
    running_loss = 0.
    example_ct = 0

    for batch_idx, dict_item in enumerate(train_loader):
        images = dict_item['image']
        labels = dict_item['label']
        print("Shape of images", images.shape)
        print("Shape of labels", labels.shape)
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
        example_ct += images.size(0)
        metrics = {
            "train/train_loss": loss.item(),
            "train/epoch": epoch,
            "train/example_ct": example_ct,
        }
        print(metrics)
    return running_loss/example_ct

def train_loop(train_loader, val_loader):
    loss_fn = DiceLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    best_vloss = 1_000_000.
    for epoch in range(3):
        print(f"Epoch:{epoch+1}")
        model.train()
        avg_train_loss = train_one_epoch(train_loader, loss_fn, optimizer, epoch)
        print("train loss", avg_train_loss)

if __name__=="__main__":

    transform = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    train_data = DecathlonDataset(
        root_dir="./", task="Task04_Hippocampus", transform=transform, section="validation", seed=12345, download=False
    )
    model = VNet(spatial_dims=3, in_channels=1, out_channels=1, act='elu')
    train_loader = monai.data.DataLoader(
        train_data, batch_size=1, num_workers=2, persistent_workers=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device {device}")

    model.to(device)

    train_loop(train_loader, val_loader=None)

Expected behavior
Training happens.

Screenshots

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 4 but got size 5 for tensor number 1 in the list.

Complete Traceback

Traceback (most recent call last):
  File "/home/linn/vnet/train.py", line 65, in <module>
    train_loop(train_loader, val_loader=None)
  File "/home/linn/vnet/train.py", line 39, in train_loop
    avg_train_loss = train_one_epoch(train_loader, loss_fn, optimizer, epoch)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/linn/vnet/train.py", line 19, in train_one_epoch
    outputs = model(images)
              ^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/monai/networks/nets/vnet.py", line 274, in forward
    x = self.up_tr256(out256, out128)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/monai/networks/nets/vnet.py", line 165, in forward
    xcat = torch.cat((out, skipxdo), 1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/monai/data/meta_tensor.py", line 282, in __torch_function__
    ret = super().__torch_function__(func, types, args, kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/torch/_tensor.py", line 1443, in __torch_function__
    ret = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^

Environment

Ensuring you use the relevant python executable, please paste the output of:

python -c "import monai; monai.config.print_debug_info()"

================================
Printing MONAI config...
================================
MONAI version: 1.3.1
Numpy version: 1.26.4
Pytorch version: 2.3.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 96bfda00c6bd290297f5e3514ea227c6be4d08b4
MONAI __file__: /data/<username>/miniconda3/envs/monoai/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.1
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
scipy version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: NOT INSTALLED or UNKNOWN VERSION.
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: 4.66.4
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: NOT INSTALLED or UNKNOWN VERSION.
pandas version: NOT INSTALLED or UNKNOWN VERSION.
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

================================
Printing system config...
================================
`psutil` required for `print_system_info`

================================
Printing GPU config...
================================
Num GPUs: 1
Has CUDA: True
CUDA version: 12.1
cuDNN enabled: True
NVIDIA_TF32_OVERRIDE: None
TORCH_ALLOW_TF32_CUBLAS_OVERRIDE: None
cuDNN version: 8902
Current device: 0
Library compiled for CUDA architectures: ['sm_50', 'sm_60', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90']
GPU 0 Name: NVIDIA A100-PCIE-40GB
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 108
GPU 0 Total memory (GB): 39.4
GPU 0 CUDA capability (maj.min): 8.0

KumoLiu commented 4 months ago

Hi @linnabraham, this looks like a shape mismatch issue. Did you try checking your input data shape before sending it to the model?
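
As a sketch of that check (and of one way to make the volumes V-Net friendly), the snippet below prints the batch shapes and pads every spatial dimension up to a multiple of 16 with DivisiblePadd, since V-Net halves the spatial size four times and concatenates the results with skip connections; treat it as an illustration rather than a fix confirmed in this thread:

from monai.apps import DecathlonDataset
from monai.data import DataLoader
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityd, DivisiblePadd, ToTensord
)

# Pad every spatial dimension up to a multiple of 16 so that the four stride-2
# down-transitions produce feature maps whose sizes still match the skip
# connections when they are concatenated on the way back up.
transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityd(keys="image"),
        DivisiblePadd(keys=["image", "label"], k=16),
        ToTensord(keys=["image", "label"]),
    ]
)

data = DecathlonDataset(
    root_dir="./", task="Task04_Hippocampus", transform=transform,
    section="validation", seed=12345, download=False
)
loader = DataLoader(data, batch_size=1, num_workers=0)

for batch in loader:
    # Expect (B, C, D, H, W) with D, H, W all divisible by 16.
    print(batch["image"].shape, batch["label"].shape)
    break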

linnabraham commented 4 months ago

@KumoLiu I did now, and it seems that the Decathlon data shape is not compatible with V-Net. I was not expecting that, since I had earlier used the same data with a TensorFlow implementation of V-Net (https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/Segmentation/VNet).

But I found an issue with the V-Net implementation: out_channels seems to be hard-coded as 16, which implies that in_channels can only be 1 or 16. I have re-opened a bug report that was closed without a proper fix: https://github.com/Project-MONAI/MONAI/issues/4896

KumoLiu commented 4 months ago

> out_channels seems to be hard-coded as 16, which implies that in_channels can only be 1 or 16. I have re-opened a bug report that was closed without a proper fix: #4896

The in_channels can be a multiple of 16:
https://github.com/Project-MONAI/MONAI/blob/64ea76d83a92b7cf7f13c8f93498d50037c3324c/monai/networks/nets/vnet.py#L211
In TensorFlow, they also set out_channels to 16:
https://github.com/NVIDIA/DeepLearningExamples/blob/729963dd47e7c8bd462ad10bfac7a7b0b604e6dd/TensorFlow/Segmentation/VNet/model/vnet.py#L34
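
A quick way to see which in_channels values the released VNet actually accepts is simply to try constructing and running it; the probe below is a throwaway sketch (the values tried are arbitrary examples, and the exact failure messages depend on the installed MONAI version):

import torch
from monai.networks.nets import VNet

# Probe which in_channels values the current VNet accepts; the values tried
# here are arbitrary examples, and failures are caught and printed.
for in_ch in (1, 2, 4, 8, 16, 3):
    try:
        net = VNet(spatial_dims=3, in_channels=in_ch, out_channels=2, act="elu")
        with torch.no_grad():
            out = net(torch.randn(1, in_ch, 32, 32, 32))
        print(f"in_channels={in_ch}: ok, output shape {tuple(out.shape)}")
    except Exception as exc:  # constructor or forward may reject the value
        print(f"in_channels={in_ch}: {exc}")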

linnabraham commented 4 months ago

Thanks for pointing out the TensorFlow code, but I am still confused. My input has shape (64, 128, 128). I edited the source code so that 16 is no longer hard-coded, but no matter what I pass as out_channels (1, 16, 64, 128), I get a shape mismatch error. What should I do?

KumoLiu commented 4 months ago

If your shape is (64, 128, 128), then your spatial_dims should be 2 since 64 is the channel dim.
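
For reference, V-Net expects a batch of shape (batch, channels, D, H, W) when spatial_dims=3 and (batch, channels, H, W) when spatial_dims=2, so a volume printed as (64, 128, 128) is ambiguous until an explicit channel axis is added. A small sketch with random single-channel inputs (shapes chosen only for illustration, all divisible by 16):

import torch
from monai.networks.nets import VNet

# spatial_dims=3 expects a 5D batch: (batch, channels, D, H, W)
net3d = VNet(spatial_dims=3, in_channels=1, out_channels=2, act="elu")
print(net3d(torch.randn(1, 1, 32, 64, 64)).shape)   # torch.Size([1, 2, 32, 64, 64])

# spatial_dims=2 expects a 4D batch: (batch, channels, H, W)
net2d = VNet(spatial_dims=2, in_channels=1, out_channels=2, act="elu")
print(net2d(torch.randn(1, 1, 128, 128)).shape)     # torch.Size([1, 2, 128, 128])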

linnabraham commented 4 months ago

Thanks for pointing that out. I set it to 2. I could not use 16 as out_channels, so I tried 64 instead. Now I get this error:

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [32, 64, 128, 128, 1]

kvttt commented 3 months ago

Hi @linnabraham, can you try removing EnsureChannelFirstd(keys=["image", "label"]), such that transform is given by

transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        ScaleIntensityd(keys="image"),
        ToTensord(keys=["image", "label"]),
    ]
)

because it looks like EnsureChannelFirstd adds an extra singleton dimension to the image, making it 3D instead of 2D.
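
A minimal way to confirm the effect of this suggestion before training, assuming the same DecathlonDataset setup as in the original script: with the reduced transform, the batch handed to the model should be 4D (batch, channels, H, W), which is what conv2d requires when spatial_dims=2. This is only an illustrative check, not part of the suggestion above.

import monai
from monai.apps import DecathlonDataset
from monai.transforms import Compose, LoadImaged, ScaleIntensityd, ToTensord

# Reduced transform as suggested above: without EnsureChannelFirstd, no extra
# axis is inserted and the first axis of each loaded volume acts as the channel axis.
transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        ScaleIntensityd(keys="image"),
        ToTensord(keys=["image", "label"]),
    ]
)

data = DecathlonDataset(
    root_dir="./", task="Task04_Hippocampus", transform=transform,
    section="validation", seed=12345, download=False
)
loader = monai.data.DataLoader(data, batch_size=1, num_workers=0)

batch = monai.utils.first(loader)
# conv2d needs a 4D batched input: (batch, channels, H, W)
print(batch["image"].shape, "->", batch["image"].ndim, "dims")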