Project-MONAI / tutorials

MONAI Tutorials
https://monai.io/started.html
Apache License 2.0

Issue with dimensions on output from the "Spleen 3D tutorial" - outputs "Multi-Component" images in ITK-snap #1005

Closed · carlpe closed this issue 1 year ago

carlpe commented 1 year ago

Greetings,

I am training on a dataset with 2 classes (plus background), using code based on the "Spleen 3D segmentation" tutorial. Training seems to be working well, but after running inference the output dimensions are incorrect. When I open the outputs in either ITK-SNAP or Slicer, the NIfTI files are reported as "Multi-Component".

It seems the issue is related either to "to_onehot=x" in the AsDiscreted transform or to "squeeze_end_dims=True" in the SaveImaged transform.
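
To illustrate, here is a minimal sketch (with a dummy tensor, not my actual data) of what "to_onehot=3" does to the prediction shape:

import torch
from monai.transforms import AsDiscrete

pred = torch.randn(3, 64, 64, 64)  # network output: (channels, H, W, D)
post = AsDiscrete(argmax=True, to_onehot=3)
print(post(pred).shape)  # torch.Size([3, 64, 64, 64]): argmax collapses the
                         # prediction to 1 channel, to_onehot expands it back to 3

Since the channel axis is not a trailing size-1 dimension, "squeeze_end_dims=True" cannot remove it when the image is saved.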

I have seen some similar issues here on the MONAI GitHub: https://github.com/Project-MONAI/tutorials/issues/433 and https://github.com/Project-MONAI/MONAI/issues/1677

As you can see in ITK-SNAP, there are 3 "image components" to choose from in the top-left corner of each projection frame:

[screenshots: ITK-SNAP viewer with the image-component selector showing 3 components]
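
Checking one of the saved files with nibabel confirms the extra axis (the filename below is just an example):

import nibabel as nib

seg = nib.load("./out/example_seg.nii.gz")  # hypothetical output filename
print(seg.shape)  # an extra dimension of size 3, which ITK-SNAP
                  # presents as three "image components"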

I am using the latest MONAI Docker image (monai:latest):

MONAI version: 1.0.0+53.gcd2f4e15
Numpy version: 1.22.2
Pytorch version: 1.13.0a0+d0d6b1f
MONAI flags: HAS_EXT = True, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: cd2f4e15a050f2ff1abc66bf19d499ce3cd58a31
MONAI __file__: /opt/monai/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.10
Nibabel version: 4.0.2
scikit-image version: 0.19.3
Pillow version: 9.0.1
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: 4.5.2
TorchVision version: 0.14.0a0
tqdm version: 4.64.1
lmdb version: 1.3.0
psutil version: 5.9.2
pandas version: 1.4.4
einops version: 0.5.0
transformers version: 4.21.3
mlflow version: 1.29.0
pynrrd version: 1.0.0

I have one .py script for training and one for inference.

Training:

from monai.transforms import (
    AsDiscrete,
    Compose,
    CropForegroundd,
    EnsureChannelFirstd,
    LoadImaged,
    RandAdjustContrastd,
    RandCropByPosNegLabeld,
    RandFlipd,
    RandGaussianSmoothd,
    RandRotated,
    RandScaleIntensityd,
)
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.utils import first
import torch
import os
import glob

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # must be set in the environment to take effect
data_root = "/monai/2"
path_to_target = "/data/norm/"

train_images = sorted(glob.glob(os.path.join(path_to_target, "*.nii")))
train_labels = sorted(glob.glob(os.path.join(path_to_target, "*.nii.gz")))

data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]
train_files, val_files = data_dicts[:300], data_dicts[300:350]
print(len(train_files))
print(len(val_files))

train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(320, 320, 320),
            pos=1,
            neg=1,
            num_samples=1,
            image_key="image",
            image_threshold=0,
        ),
        RandAdjustContrastd(keys="image", prob=0.3),
        RandGaussianSmoothd(keys="image", prob=0.3),
        RandRotated(
            keys=["image", "label"], range_x=0.3, range_y=0.4, range_z=0.4, prob=0.4
        ),
        RandFlipd(keys=["image", "label"], prob=0.4, spatial_axis=[0, 1, 2]),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.3),
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        CropForegroundd(keys=["image", "label"], source_key="image"),
    ]
)

check_ds = Dataset(data=val_files, transform=val_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)
image, label = (check_data["image"][0][0], check_data["label"][0][0])
print(f"image shape: {image.shape}, label shape: {label.shape}")

train_ds = CacheDataset(
    data=train_files, transform=train_transforms, cache_rate=0.2, num_workers=1
)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=1)
val_ds = CacheDataset(
    data=val_files, transform=val_transforms, cache_rate=0.2, num_workers=1
)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=1)

device = torch.device("cuda:0")

model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=3,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

loss_function = DiceLoss(to_onehot_y=True, softmax=True)
model.load_state_dict(torch.load("plavo.pth"))
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")

max_epochs = 600
val_interval = 1
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=3)])
post_label = Compose([AsDiscrete(to_onehot=3)])

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(
            f"{step}/{len(train_ds) // train_loader.batch_size}, "
            f"Train_loss: {loss.item():.8f}"
        )
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"Epoch {epoch + 1} Average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                roi_size = (320, 320, 320)
                sw_batch_size = 1
                val_outputs = sliding_window_inference(
                    val_inputs, roi_size, sw_batch_size, model
                )
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)

            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()

            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(data_root, "plavo.pth"))
                print("Saved new best metric model")
            print(
                f"Current epoch: {epoch + 1} Current mean dice: {metric:.4f}"
                f"\nBest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}"
            )

print(
    f"Training completed, Best_metric: {best_metric:.4f} "
    f"at epoch: {best_metric_epoch}"
)

And the code for inference:

from monai.utils import first
from monai.transforms import (
    AsDiscreted,
    Compose,
    CropForegroundd,
    EnsureChannelFirstd,
    LoadImaged,
    Orientationd,
    SaveImaged,
)
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.inferers import sliding_window_inference
from monai.data import DataLoader, Dataset, decollate_batch
import torch
import os
import glob

device = torch.device("cpu")
root_dir = "/monai/2"
data_dir = "/data/norm/test"
test_images = sorted(glob.glob(os.path.join(data_dir, "*.nii")))
test_data = [{"image": image} for image in test_images]

test_org_transforms = Compose(
    [
        LoadImaged(keys="image"),
        EnsureChannelFirstd(keys="image"),
        Orientationd(keys=["image"], axcodes="RAS"),
        CropForegroundd(keys=["image"], source_key="image"),
    ]
)

check_ds = Dataset(data=test_data, transform=test_org_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)
image = check_data["image"][0][0]
print(f"image shape: {image.shape}")

test_org_ds = Dataset(data=test_data, transform=test_org_transforms)
test_org_loader = DataLoader(test_org_ds, batch_size=1, num_workers=1)

post_transforms = Compose(
    [
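        # to_onehot=3 expands the argmax result back to 3 channels; SaveImaged
        # then writes them as a multi-component NIfTI (squeeze_end_dims only
        # removes trailing size-1 dimensions, so the channel axis survives)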
        AsDiscreted(keys="pred", argmax=True, to_onehot=3),
        SaveImaged(
            keys="pred",
            # meta_keys="pred_meta_dict",
            output_dir="./out",
            output_postfix="seg",
            separate_folder=False,
            resample=False,
            squeeze_end_dims=True,
        ),
    ]
)

model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=3,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

model.load_state_dict(torch.load(os.path.join(root_dir, "plavo.pth")))

model.eval()
with torch.no_grad():
    for test_data in test_org_loader:
        test_inputs = test_data["image"].to(device)
        roi_size = (320, 320, 320)
        sw_batch_size = 1
        test_data["pred"] = sliding_window_inference(
            test_inputs, roi_size, sw_batch_size, model)

        test_data = [post_transforms(i) for i in decollate_batch(test_data)]

I hope to understand why this is happening; surely other users have run into the same issue?

KumoLiu commented 1 year ago

Hi @carlpe, you don't need to one-hot your output before SaveImage; that is what turns your segmentation into 3 channels. You can simply do it like this:

from monai.transforms import Activations, AsDiscrete, Compose, SaveImage

post_trans = Compose([Activations(softmax=True), AsDiscrete(argmax=True)])
saver = SaveImage(output_dir="./output", output_ext=".nii.gz", output_postfix="seg")
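
For the dictionary-based pipeline in the original script, the equivalent fix (a sketch keeping the same keys and output settings) is to drop to_onehot so the saved prediction stays a single channel:

from monai.transforms import AsDiscreted, Compose, SaveImaged

post_transforms = Compose(
    [
        AsDiscreted(keys="pred", argmax=True),  # no to_onehot: prediction stays (1, H, W, D)
        SaveImaged(
            keys="pred",
            output_dir="./out",
            output_postfix="seg",
            separate_folder=False,
            resample=False,
        ),
    ]
)

ITK-SNAP and Slicer will then read each output as an ordinary 3D label map with voxel values 0, 1, and 2.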

Thanks!