Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

Cannot run V-Net on medical decathlon data #7852

Open linnabraham opened 1 week ago

linnabraham commented 1 week ago

Describe the bug

PyTorch raises a tensor size mismatch error when V-Net is run on Medical Segmentation Decathlon data (Task04_Hippocampus).

To Reproduce

import monai
from monai.apps import DecathlonDataset
from monai.transforms import LoadImaged, EnsureChannelFirstd, ScaleIntensityd, ToTensord, Compose
from monai.networks.nets import VNet
from monai.losses.dice import DiceLoss
import torch

def train_one_epoch(train_loader, loss_fn, optimizer, epoch):
    running_loss = 0.
    example_ct = 0

    for batch_idx, dict_item in enumerate(train_loader):
        images = dict_item['image']
        labels = dict_item['label']
        print("Shape of images", images.shape)
        print("Shape of labels", labels.shape)
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs,labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
        example_ct += images.size(0)
        metrics = {
            "train/train_loss": loss.item(),
            "train/epoch": epoch,
            "train/example_ct": example_ct,
        }
        print(metrics)
    return running_loss/example_ct

def train_loop(train_loader, val_loader):
    loss_fn = DiceLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    best_vloss = 1_000_000.
    for epoch in range(3):
        print(f"Epoch:{epoch+1}")
        model.train()
        avg_train_loss = train_one_epoch(train_loader, loss_fn, optimizer, epoch)
        print("train loss", avg_train_loss)

if __name__=="__main__":

    transform = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            ScaleIntensityd(keys="image"),
            ToTensord(keys=["image", "label"]),
        ]
    )

    train_data = DecathlonDataset(
        root_dir="./", task="Task04_Hippocampus", transform=transform, section="validation", seed=12345, download=False
    )
    model = VNet(spatial_dims=3, in_channels=1, out_channels=1, act='elu')
    train_loader = monai.data.DataLoader(
        train_data, batch_size=1, num_workers=2, persistent_workers=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using device {device}")

    model.to(device)

    train_loop(train_loader, val_loader=None)

Expected behavior

Training runs without errors.

Screenshots

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 4 but got size 5 for tensor number 1 in the list.

Complete Traceback

Traceback (most recent call last):
  File "/home/linn/vnet/train.py", line 65, in <module>
    train_loop(train_loader, val_loader=None)
  File "/home/linn/vnet/train.py", line 39, in train_loop
    avg_train_loss = train_one_epoch(train_loader, loss_fn, optimizer, epoch)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/linn/vnet/train.py", line 19, in train_one_epoch
    outputs = model(images)
              ^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/monai/networks/nets/vnet.py", line 274, in forward
    x = self.up_tr256(out256, out128)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/monai/networks/nets/vnet.py", line 165, in forward
    xcat = torch.cat((out, skipxdo), 1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/monai/data/meta_tensor.py", line 282, in __torch_function__
    ret = super().__torch_function__(func, types, args, kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/linn/miniconda3/envs/monoai/lib/python3.11/site-packages/torch/_tensor.py", line 1443, in __torch_function__
    ret = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^

Environment

Ensuring you use the relevant python executable, please paste the output of:

python -c "import monai; monai.config.print_debug_info()"

================================
Printing MONAI config...
================================
MONAI version: 1.3.1
Numpy version: 1.26.4
Pytorch version: 2.3.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 96bfda00c6bd290297f5e3514ea227c6be4d08b4
MONAI __file__: /data/<username>/miniconda3/envs/monoai/lib/python3.11/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.1
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
scipy version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: NOT INSTALLED or UNKNOWN VERSION.
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: 4.66.4
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: NOT INSTALLED or UNKNOWN VERSION.
pandas version: NOT INSTALLED or UNKNOWN VERSION.
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

================================
Printing system config...
================================
`psutil` required for `print_system_info`

================================
Printing GPU config...
================================
Num GPUs: 1
Has CUDA: True
CUDA version: 12.1
cuDNN enabled: True
NVIDIA_TF32_OVERRIDE: None
TORCH_ALLOW_TF32_CUBLAS_OVERRIDE: None
cuDNN version: 8902
Current device: 0
Library compiled for CUDA architectures: ['sm_50', 'sm_60', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90']
GPU 0 Name: NVIDIA A100-PCIE-40GB
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 108
GPU 0 Total memory (GB): 39.4
GPU 0 CUDA capability (maj.min): 8.0

KumoLiu commented 3 days ago

Hi @linnabraham, this looks like a shape mismatch. Have you checked the shape of your input data before passing it to the model?
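
One way to check this without running the network is to trace a single spatial size through the encoder and decoder by hand. The sketch below is illustrative only. It assumes four down-sampling stages that each halve the size with floor division (a kernel-2, stride-2 convolution) and up-sampling stages that each double it. The axis length 43 is a hypothetical value, not one read from the dataset, and the helper names are made up for this example.

```python
# Assumption: four stride-2 down-sampling stages, each followed later by a
# stride-2 up-sampling stage whose output is concatenated with a skip tensor.
NUM_DOWN_STAGES = 4


def encoder_sizes(size, stages=NUM_DOWN_STAGES):
    """Return the spatial size at the input and after each down-sampling stage."""
    sizes = [size]
    for _ in range(stages):
        sizes.append(sizes[-1] // 2)  # stride-2 convolution floors odd sizes
    return sizes


def first_mismatch(size, stages=NUM_DOWN_STAGES):
    """Return (upsampled, skip) at the first decoder stage where they differ, else None."""
    sizes = encoder_sizes(size, stages)
    for level in range(stages, 0, -1):
        upsampled = sizes[level] * 2  # transposed convolution doubles the size
        skip = sizes[level - 1]       # encoder feature map used as the skip input
        if upsampled != skip:
            return upsampled, skip
    return None


def pad_to_multiple(size, k=2 ** NUM_DOWN_STAGES):
    """Round a size up to the next multiple of k (16 under the assumption above)."""
    return -(-size // k) * k


if __name__ == "__main__":
    hypothetical_size = 43  # illustrative axis length, not taken from the data
    print(encoder_sizes(hypothetical_size))
    print(first_mismatch(hypothetical_size))
    print(pad_to_multiple(hypothetical_size))
    print(first_mismatch(pad_to_multiple(hypothetical_size)))
```

Under these assumptions, any axis length that is not a multiple of 16 fails at some decoder stage, and a length between 40 and 47 fails at the deepest one with sizes 4 and 5, as in the reported error. If the printed image shapes confirm this, padding or resizing each volume so that every spatial dimension is a multiple of 16 (for example with MONAI's `DivisiblePadd` applied to both `image` and `label` with `k=16`) may be enough to get training started.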