Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

2D Segmentation Example #920

Closed · jillianlee closed this issue 4 years ago

jillianlee commented 4 years ago

Hi There,

Do you plan on posting examples or workbooks for any 2D segmentation problems?

I'm having some issues trying to get predictions to work with 2D input data. The same data given in 3D to the same network works just fine, but with the 2D data all my predicted masks come out black. I've tried changing quite a few of the parameters and optimizer settings, but nothing has helped. I think an example would be helpful if you have time for one!

Thanks Jillian

Nic-Ma commented 4 years ago

Hi @jillianlee ,

Thanks for your interest. A 2D segmentation task should differ from 3D only slightly, mainly in the network dims and the spatial shapes used in the transforms; a rough sketch of the difference is below. Could you please paste your test program so I can check the details?
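For example, a rough untested sketch of where the two setups usually differ (the parameter values here are only placeholders):

from monai.networks.nets import UNet
from monai.transforms import Resized

# 3D: volumetric network and 3D spatial sizes in the transforms
net_3d = UNet(dimensions=3, in_channels=1, out_channels=1, channels=(16, 32, 64), strides=(2, 2))
resize_3d = Resized(keys=['image', 'label'], spatial_size=(160, 160, 72))

# 2D: same network class with dimensions=2 and 2D spatial sizes
net_2d = UNet(dimensions=2, in_channels=1, out_channels=1, channels=(16, 32, 64), strides=(2, 2))
resize_2d = Resized(keys=['image', 'label'], spatial_size=(160, 160))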

Thanks.

jillianlee commented 4 years ago

Hi @Nic-Ma

I assumed as much, and I'm not getting any errors; it just looks like a vanishing-gradient problem, but the usual solutions are not helping. I've tried changing the activation functions in the UNet, the loss functions, and the optimizers, and nothing has helped.

import tempfile

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.metrics import DiceMetric
from monai.transforms import (
    Compose, LoadNiftid, AddChanneld, MapTransform, Orientationd, Resized, RandSpatialCropSamplesd,
    RandFlipd, SqueezeDimd, RandScaleIntensityd, RandShiftIntensityd, ToNumpyd)
from monai.utils import set_determinism


class ConvertLabel(MapTransform):
    """
    Convert labels to multi channels based on brats classes:
    label 1 is the peritumoral edema
    label 2 is the GD-enhancing tumor
    label 3 is the necrotic and non-enhancing tumor core
    The possible classes are TC (tumor core), WC (whole tumor) and ET (enhancing tumor);
    we want WC because FLAIR shows the whole tumor.

    """
    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            # merge labels 1, 2 and 3 to construct WC
            result = np.logical_or(np.logical_or(d[key] == 2, d[key] == 3), d[key] == 1)
            d[key] = result.astype(np.float32)
        return d

monai.config.print_config()
set_determinism(seed=0)
tempdir = tempfile.mkdtemp()

text_t1 = open(r'C:\Users\iBest\Documents\Jillian\Documents\BO-Aug-master\dataset\BraTS\filename_flair.txt', 'r') 
train_images = text_t1.read().split('\n')

text_segs = open(r'C:\Users\iBest\Documents\Jillian\Documents\BO-Aug-master\dataset\BraTS\filename_seg.txt', 'r') 
train_labels = text_segs.read().split('\n')

data_dicts = [{'image': image_name, 'label': label_name}
              for image_name, label_name in zip(train_images, train_labels)]

transforms = Compose([
    LoadNiftid(keys=['image', 'label']),
    AddChanneld(keys=['image', 'label']),
    ConvertLabel(keys='label'),
    #Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.), mode=('bilinear', 'nearest')),
    Orientationd(keys=['image', 'label'], axcodes='RAS'),
    Resized(keys=['image', 'label'], spatial_size=(160, 160, 72)),
    RandSpatialCropSamplesd(keys=['image', 'label'], roi_size=[160, 160, 1], num_samples=72, random_center=True, random_size=False),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    SqueezeDimd(keys=['image', 'label'], dim=-1),
    #ScaleIntensityd(keys='image', minv=0.0, maxv=1.0),
    RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
    RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
    ToNumpyd(keys=['image', 'label'])
])

transforms_val = Compose([
    LoadNiftid(keys=['image', 'label']),
    AddChanneld(keys=['image', 'label']),
    ConvertLabel(keys='label'),
    #Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 2.), mode=('bilinear', 'nearest')),
    Orientationd(keys=['image', 'label'], axcodes='RAS'),
    Resized(keys=['image', 'label'], spatial_size=(160, 160, 72)),
    RandSpatialCropSamplesd(keys=['image', 'label'], roi_size=[160, 160, 1], num_samples=72, random_center=True, random_size=False),
    SqueezeDimd(keys=['image', 'label'], dim=-1),
    #ScaleIntensityd(keys='image', minv=0.0, maxv=1.0),
    ToNumpyd(keys=['image', 'label'])
])

train_files, val_files, test_files = data_dicts[0:171], data_dicts[171:228], data_dicts[228:285]

train_ds = monai.data.ArrayDataset(train_files, img_transform=transforms)
train_load = monai.data.DataLoader(train_ds)

val_ds = monai.data.ArrayDataset(val_files, img_transform=transforms_val)
val_load = monai.data.DataLoader(val_ds)

test_ds = monai.data.ArrayDataset(test_files, img_transform=transforms_val)
test_load = monai.data.DataLoader(test_ds)

flat_test = [item for sublist in test_load.dataset for item in sublist]
test_data = [sub['image'] for sub in flat_test]
test_data = [np.squeeze(im, 0) for im in test_data]
test_labels = [sub['label'] for sub in flat_test]
test_labels = [np.squeeze(lab, 0) for lab in test_labels]

flat_train = [item for sublist in train_load.dataset for item in sublist]
train_images = [sub['image'] for sub in flat_train]
train_images = [np.squeeze(im, 0) for im in train_images]
train_labels = [sub['label'] for sub in flat_train]
train_labels = [np.squeeze(lab, 0) for lab in train_labels]

flat_val = [item for sublist in val_load.dataset for item in sublist]
val_images = [sub['image'] for sub in flat_val]
val_images = [np.squeeze(im, 0) for im in val_images]
val_labels = [sub['label'] for sub in flat_val]
val_labels = [np.squeeze(lab, 0) for lab in val_labels]

dict_train = [{'image': image_data, 'label': label_data} 
                for image_data, label_data in zip(train_images, train_labels)]

dict_val = [{'image': image_data, 'label': label_data}
              for image_data, label_data in zip(val_images, val_labels)]
dict_test = [{'image': image_data, 'label': label_data}
              for image_data, label_data in zip(test_data, test_labels)]

transforms = Compose([
    AddChanneld(keys=['image', 'label']),
    ToNumpyd(keys=['image', 'label'])
    ])

train_data = monai.data.ArrayDataset(dict_train, img_transform=transforms)
train_loader = monai.data.DataLoader(train_data, batch_size=3, shuffle=True, num_workers=0, multiprocessing_context=None)

val_data = monai.data.ArrayDataset(dict_val, img_transform=transforms)
val_loader = monai.data.DataLoader(val_data, batch_size=1, shuffle=True, num_workers=0, multiprocessing_context=None)

test_data = monai.data.ArrayDataset(dict_test, img_transform=transforms)
test_loader = monai.data.DataLoader(test_data, batch_size=1, shuffle=True, num_workers=0, multiprocessing_context=None)

device = torch.device('cuda:0')
model = monai.networks.nets.UNet(dimensions=2, in_channels=1, out_channels=1, channels=(16, 32, 64),
                                 strides=(2, 2), act='elu').to(device)
loss_function = monai.losses.DiceLoss(to_onehot_y=False, sigmoid=True)

optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, to_onehot_y=False, sigmoid=True, reduction="mean")

num_epochs = 100

# TRAINING
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()

writer = SummaryWriter()
for epoch in range(num_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{20}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        #inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
        inputs, labels = batch_data['image'].to(device), batch_data['label'].to(device)
        optimizer.zero_grad()
        outputs = model(inputs.type(torch.cuda.FloatTensor))
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_data) // train_loader.batch_size
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        #writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            val_images = None
            val_labels = None
            val_outputs = None
            for val_data in val_loader:
                #val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                val_images, val_labels = val_data['image'].to(device), val_data['label'].to(device)
                roi_size = (160, 160)
                sw_batch_size = 4
                #val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                val_outputs = model(val_images)
                value = dice_metric(y_pred=val_outputs, y=val_labels)
                metric_count += len(value)
                metric_sum += value.item() * len(value)

            metric = metric_sum / metric_count
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "best_metric_model.pth")
                print("saved new best metric model")
            print(
                "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch
                )
            )
            #writer.add_scalar("val_mean_dice", metric, epoch + 1)
            #plot the last model output as GIF image in TensorBoard with the corresponding image and label
            #shutil.rmtree(tempdir)

The network parameters, optimizer, and loss function seem to work fine with the 3D data fed from ArrayDataset. When I perform sampling, the images look fine, but the predicted segmentation masks come out black and the peak DSC is around 0.2.

I wondered if there was an issue with the array management/sampling, so I tried running a 2D UNet with the RandSpatialCropSamplesd transform, but this causes a batch-size error in the validation set: a validation batch size of 1 combined with RandSpatialCropSamplesd produces a batch of 1 × num_samples, which throws an error in the validation loop.
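Roughly what I mean, as a sketch of the shapes (using monai.data.list_data_collate explicitly; I haven't verified this is the right way to handle it):

from torch.utils.data import DataLoader
from monai.data import Dataset, list_data_collate

# RandSpatialCropSamplesd turns each subject into a list of num_samples=72 crops,
# so even with batch_size=1 the collated validation batch contains 72 patches
val_ds = Dataset(data=val_files, transform=transforms_val)
val_loader = DataLoader(val_ds, batch_size=1, collate_fn=list_data_collate)
batch = next(iter(val_loader))
print(batch['image'].shape)  # roughly (72, 1, 160, 160) rather than (1, 1, 160, 160)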

The 2D numpy array sampling step is something I need for my own data augmentation experiments, so unfortunately I cannot omit it.

Any advice would be appreciated :) Thanks, Jillian

Nic-Ma commented 4 years ago

Hi @jillianlee ,

Two points I found that may be helpful:

  1. As your data is in dict format, please use monai.data.Dataset or monai.data.CacheDataset instead of ArrayDataset (see the sketch after this list).
  2. Are you using the brain tumor dataset from Decathlon Task01 or the original BraTS dataset? The labels differ between them. I remember the original BraTS uses labels 1, 2, and 4, so you need to double-check that.
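For example, a rough untested sketch of both points (the label merge is only my assumption about your data):

from monai.data import Dataset, DataLoader

# dict-style data goes directly into Dataset with the dictionary transform chain
train_ds = Dataset(data=train_files, transform=transforms)
train_loader = DataLoader(train_ds, batch_size=3, shuffle=True, num_workers=0)

# if it is the original BraTS labelling (1, 2, 4), the merge in ConvertLabel would become:
# result = np.isin(d[key], (1, 2, 4))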

Thanks.

jillianlee commented 4 years ago

Hi @Nic-Ma

  1. Oh right, my apologies. Previously I had used array data with the ArrayDataset, and I switched to dict-type data while debugging. I have now switched to Dataset, with no improvement.

  2. I also just tried labels 1, 2, and 4 for BraTS, and no improvement there either. The masks still all come out black.

Essentially this entire code works perfectly fine with the 3D data, but as soon as it's sliced into 2D images, it breaks. The 3D U-Net obtains a reasonable DSC in the high 0.7s.

The 2D code is learning something, but not much. Below are the graphs for the accuracy and DSC over epochs. It just gets stuck and keeps predicting black masks, even with GeneralizedDiceLoss and alternative activation functions to compensate for the class imbalance. I have also tried the sliding-window inference predictor, which gave a similar output graph, just slightly smoother.

[figure: accuracy and DSC curves over training epochs]

jillianlee commented 4 years ago

I'm also getting the same behavior with the 2D HighResNet and VNet.

From what I've read on other forums, it's a common problem for class-imbalanced tasks. The suggested fixes include changing the activation function to tanh or elu, which I have tried, or changing the loss function to Generalized Dice Loss, Tversky Loss, or Masked Dice Loss. These adjustments haven't helped either.

kissievendor commented 4 years ago

I am working on something similar. Could you share the whole code for ConvertLabel?

I think this would also be a nice addition to MONAI.

jillianlee commented 4 years ago

Hi @kissievendor

class ConvertLabel(MapTransform):
    """
    Convert labels to multi channels based on brats classes:
    label 1 is the peritumoral edema
    label 2 is the GD-enhancing tumor
    label 3 is the necrotic and non-enhancing tumor core
    The possible classes are TC (tumor core), WC (whole tumor) and ET (enhancing tumor);
    we want WC because FLAIR shows the whole tumor.
    """

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            # merge labels 1, 2 and 3 to construct WC
            result = np.logical_or(np.logical_or(d[key] == 2, d[key] == 3), d[key] == 1)
            d[key] = result.astype(np.float32)
        return d

Here is the whole code for ConvertLabel. It's modified from the BraTS example that they had previously posted in their examples folder (I believe it's been taken down now?). I set it up this way to output only one binary label for one output channel. This combination of labels is mentioned in the BraTS reference paper; it is the combination that is visible in the FLAIR data.
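As a quick sketch of what it does, on a made-up toy label array:

import numpy as np

label = np.array([[0, 1, 2], [3, 0, 1]])
out = ConvertLabel(keys='label')({'label': label})['label']
print(out)  # [[0. 1. 1.] [1. 0. 1.]] -- labels 1, 2 and 3 all merged into one binary WC mask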

kissievendor commented 4 years ago

(quoted @jillianlee's reply above with the ConvertLabel code)

Thanks! Where does the MapTransform come from? It gives an error for me.

Have you had more luck with your training / loss?

jillianlee commented 4 years ago

Oh, you have to import MapTransform from monai.transforms. Here are all my imports (although some are unused):

import monai
from monai.transforms import (
    Compose, LoadNiftid, LoadNifti, AddChanneld, AddChannel, ScaleIntensityRanged, CropForegroundd,
    Resized, ToNumpyd, SqueezeDimd, RandSpatialCropd, RandSpatialCropSamplesd, MapTransform, Spacingd,
    Orientationd, ScaleIntensityd, ToTensor, RandFlipd, RandScaleIntensityd, ToNumpy, RandShiftIntensityd)

No, I've had no luck with 2D. It works fine in 3D with all the same network and loss settings, so I'm not sure what's going on.

kissievendor commented 4 years ago

(quoted @jillianlee's reply above with the import list)

Oh, of course! Thanks. I am working with brain data and very small annotations. 3D could mean more information from the neighboring slices? Try training on different orientations (as in QuickNAT). Good luck!
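By different orientations I just mean slicing the volume along each axis, something like this rough numpy sketch (not actual QuickNAT code):

import numpy as np

volume = np.random.rand(160, 160, 72)  # stand-in for a brain volume
axial = [volume[:, :, k] for k in range(volume.shape[2])]
coronal = [volume[:, k, :] for k in range(volume.shape[1])]
sagittal = [volume[k, :, :] for k in range(volume.shape[0])]
# training on slices from all three views gives small annotations more context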

jillianlee commented 4 years ago

@kissievendor

So I made some mistakes in my code, and it is performing better now. Not the greatest, but still better!

  1. Try increasing the channels to something like (64, 128, 256) or higher.

  2. I cropped out some of the black space and centered more around the brain.

  3. I had made an error in my figure plots in the testing loop, so they were only outputting black. The new code is:

        plt.figure("check", (18, 6))
        plt.subplot(1, 3, 1)
        plt.title("image")
        plt.imshow(test_data['image'][0, 0, :, :], cmap="gray")
        plt.subplot(1, 3, 2)
        plt.title("label")
        plt.imshow(test_data['label'][0, 0, :, :])
        plt.subplot(1, 3, 3)
        plt.title("output")
        plt.imshow((val_outputs).detach().cpu()[0, 0, :, :])
        plt.show()

    This is just inside the model evaluation loop so you can see all of your outputs.

  4. Try Tversky Loss with alpha = 0.3 and beta = 0.7 (a sketch combining this with point 1 follows below).
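Putting points 1 and 4 together, roughly what I changed (a sketch; everything except the channels and the loss is as in my earlier code, with device defined as before):

import monai

model = monai.networks.nets.UNet(
    dimensions=2, in_channels=1, out_channels=1,
    channels=(64, 128, 256), strides=(2, 2), act='elu',
).to(device)
loss_function = monai.losses.TverskyLoss(to_onehot_y=False, sigmoid=True, alpha=0.3, beta=0.7)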

My accuracy is around 0.50 now with Tversky Loss; I'm going to play around with some other losses and hope for more improvement.

Best of luck :)

MUKILAN-2003 commented 1 year ago

@jillianlee

Hey, I am facing an issue where the model is not training and the loss is stuck.

[figure: training loss curve, flat across epochs]

Are the model parameters and config correct?

import torch
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks.layers import Norm
from monai.networks.nets import SegResNet
from monai.transforms import Activations, AsDiscrete, Compose

model = SegResNet(
        spatial_dims=2,
        blocks_down=[1, 2, 2, 4],
        blocks_up=[1, 1, 1],
        init_filters=16,
        in_channels=3,
        out_channels=1,
        dropout_prob=0.2,
        use_conv_final=True,
        act="RELU",
        norm=Norm.BATCH,
    ).to(device)

loss_function = DiceLoss(to_onehot_y=False, sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)
dice_metric = DiceMetric(include_background=True, reduction="mean")
post_transforms = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

What am I doing wrong?

Thanks!

aymuos15 commented 1 year ago

Has there been any fix on this? Thank you.