Project-MONAI / tutorials

MONAI Tutorials
https://monai.io/started.html
Apache License 2.0

Transformation function AsDiscrete(to_onehot=4) only gives binary labels #641

Closed: ying2611 closed this issue 2 years ago

ying2611 commented 2 years ago

Hi, I am trying to do multi-class segmentation with UNet, but the transform AsDiscrete(to_onehot=4), which I use to one-hot encode my validation labels, only gives binary labels, even though the validation labels contain 4 classes (including the background).

Could you please help me with this?

I have also attached the code.

Many thanks

```python
# Loss function and metric
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
post_label = AsDiscrete(to_onehot=4)
post_pred = Compose([Activations(softmax=True), AsDiscrete(argmax=True, to_onehot=4)])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = monai.networks.nets.UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=4,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
).to(device)
loss_function = monai.losses.DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)
```
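For intuition about what `post_pred` does: `Activations(softmax=True)` followed by `AsDiscrete(argmax=True, to_onehot=4)` keeps, for each pixel, the index of the highest-scoring class and then expands that index back into 4 binary channels. The NumPy sketch below illustrates the idea only; it is not MONAI's implementation, and the helper name `argmax_onehot` and the toy logits are hypothetical. Since softmax is monotonic, applying it before argmax does not change which class wins.

```python
import numpy as np


def argmax_onehot(logits: np.ndarray, num_classes: int) -> np.ndarray:
    """Hypothetical helper: channel-first logits (C, H, W) -> one-hot (num_classes, H, W)."""
    # Index of the winning class at each pixel, shape (H, W).
    labels = np.argmax(logits, axis=0)
    # Compare every pixel's index with each class id to build binary channels.
    classes = np.arange(num_classes).reshape(-1, 1, 1)
    return (labels[np.newaxis, ...] == classes).astype(np.float32)


# Toy 4-class logits for a 1x3 image (channel-first: C=4, H=1, W=3).
logits = np.array(
    [
        [[5.0, 0.0, 0.0]],
        [[0.0, 5.0, 0.0]],
        [[0.0, 0.0, 1.0]],
        [[0.0, 0.0, 5.0]],
    ]
)
onehot = argmax_onehot(logits, num_classes=4)
print(onehot.shape)  # (4, 1, 3)
print(onehot[:, 0, :])  # exactly one 1 per pixel (column)
```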

```python
# Training code
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
writer = SummaryWriter()
for epoch in range(10):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{10}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_ds) // train_loader.batch_size
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            val_images = None
            val_labels = None
            val_outputs = None
            for val_data in val_loader:
                val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                roi_size = (96, 96)
                sw_batch_size = 4
                print('unique labels is', torch.unique(val_labels), 'before post process')
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]

                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                print('unique labels is', torch.unique(val_labels[0]), 'after post process')

                dice_metric(y_pred=val_outputs, y=val_labels)

            metric = dice_metric.aggregate().item()

            dice_metric.reset()
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "best_metric_model_segmentation2d_array.pth")
                print("saved new best metric model")
            print(
                "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch
                )
            )
            writer.add_scalar("val_mean_dice", metric, epoch + 1)

print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
```
Output log:

```text
1/5, train_loss: 0.8484
2/5, train_loss: 0.8363
3/5, train_loss: 0.8430
4/5, train_loss: 0.8450
5/5, train_loss: 0.8290
epoch 1 average loss: 0.8404
----------
epoch 2/10
1/5, train_loss: 0.8353
2/5, train_loss: 0.8110
3/5, train_loss: 0.8194
4/5, train_loss: 0.8052
5/5, train_loss: 0.8071
epoch 2 average loss: 0.8156
unique labels is tensor([0.0000, 0.9922, 0.9961, 1.0000], device='cuda:0') before post process
unique labels is tensor([0., 1.], device='cuda:0') after post process
unique labels is tensor([0.0000, 0.9922, 0.9961, 1.0000], device='cuda:0') before post process
unique labels is tensor([0., 1.], device='cuda:0') after post process
unique labels is tensor([0.0000, 0.9922, 0.9961, 1.0000], device='cuda:0') before post process
unique labels is tensor([0., 1.], device='cuda:0') after post process
```
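`DiceMetric(include_background=True, reduction="mean")` in the code above scores each one-hot channel separately and then averages the channels. A minimal NumPy sketch of the per-channel Dice formula 2|P ∩ G| / (|P| + |G|) follows; the helper name `per_channel_dice` is hypothetical, and returning NaN for a channel whose ground truth is empty (then skipping it in the mean) is an assumption made for illustration, so check the `DiceMetric` documentation for the exact rule in your MONAI version. Under that assumption, classes whose ground-truth channels are always empty never contribute a real score.

```python
import numpy as np


def per_channel_dice(pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
    """Hypothetical helper: Dice for each channel of one-hot arrays shaped (C, H, W).

    Channels with an empty ground truth get NaN (an assumption for illustration).
    """
    axes = tuple(range(1, pred.ndim))  # sum over the spatial axes only
    intersection = np.sum(pred * gt, axis=axes)
    gt_sum = np.sum(gt, axis=axes)
    denominator = gt_sum + np.sum(pred, axis=axes)
    dice = np.full(pred.shape[0], np.nan)
    has_gt = gt_sum > 0
    dice[has_gt] = 2.0 * intersection[has_gt] / denominator[has_gt]
    return dice


# Toy 2-channel example on a 1x4 image.
gt = np.array([[[1, 1, 0, 0]], [[0, 0, 1, 1]]], dtype=float)
pred = np.array([[[1, 0, 0, 0]], [[0, 1, 1, 1]]], dtype=float)
scores = per_channel_dice(pred, gt)
print(scores)  # per-channel Dice
print(np.nanmean(scores))  # mean over channels, NaNs ignored
```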
dongyang0122 commented 2 years ago

Hi @ChenY2000, one-hot encoding converts your 4-class labels into binary labels with 4 channels, so every output label value is either 0 or 1. This behavior is expected. You can refer to the definition here: https://en.wikipedia.org/wiki/One-hot
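To illustrate the behavior described above: with integer labels 0 to 3, `to_onehot=4` produces 4 binary channels, and each channel contains ones wherever its class is present. This is a NumPy-only sketch, not MONAI's code; `onehot_channel_first` is a hypothetical name, and it assumes a channel-first label of shape (1, H, W), the layout the transforms in this thread work on.

```python
import numpy as np


def onehot_channel_first(label: np.ndarray, num_classes: int) -> np.ndarray:
    """Hypothetical helper: (1, H, W) integer label map -> (num_classes, H, W) binary channels."""
    index = label[0].astype(np.int64)  # drop the singleton channel axis
    classes = np.arange(num_classes).reshape(-1, 1, 1)
    return (index[np.newaxis, ...] == classes).astype(np.float32)


label = np.array([[[0, 1], [2, 3]]])  # shape (1, 2, 2), one pixel per class
onehot = onehot_channel_first(label, num_classes=4)
print(onehot.shape)  # (4, 2, 2)
for c in range(4):
    print(c, np.unique(onehot[c]))  # every channel contains both 0 and 1
```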

ying2611 commented 2 years ago

Hi @dongyang0122, thank you so much for your reply. I double-checked the channel values. The problem is that after the tensor is transformed into binary labels with 4 channels, only the first two channels contain binary values, whereas the other two channels contain only zeros, even though the original validation labels have 4 classes in total, namely [0.0000, 0.9922, 0.9961, 1.0000].

Could you help me sort this out?

The code is attached.

Many thanks

```python
# Partial code from the training section
for val_data in val_loader:
    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
    roi_size = (96, 96)
    sw_batch_size = 4
    print('unique labels is', torch.unique(val_labels), 'before post process')
    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]

    val_labels = [post_label(i) for i in decollate_batch(val_labels)]

    print('unique labels is', torch.unique(val_labels[0][0]), 'after post process')

    print('Values for first channel after onhot encode is', torch.unique(val_labels[0][0]))
    print('Values for second channel after onhot encode is', torch.unique(val_labels[0][1]))
    print('Values for third channel after onhot encode is', torch.unique(val_labels[0][2]))
    print('Values for fourth channel after onhot encode is', torch.unique(val_labels[0][3]))
```

Output log:

```text
unique labels is tensor([0.0000, 0.9922, 0.9961, 1.0000], device='cuda:0') before post process
unique labels is tensor([0., 1.], device='cuda:0') after post process
Values for first dimention after onhot encode is tensor([0., 1.], device='cuda:0')
Values for second channel after onhot encode is tensor([0., 1.], device='cuda:0')
Values for third channel after onhot encode is tensor([0.], device='cuda:0')
Values for fourth channel after onhot encode is tensor([0.], device='cuda:0')
```
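The log above is consistent with how the labels are cast before one-hot encoding. MONAI's `one_hot` utility, which `AsDiscrete(to_onehot=...)` relies on, converts the label tensor to integer indices with `.long()`, and that cast truncates toward zero: 0.9922 and 0.9961 both become class 0, and only 1.0 becomes class 1. The NumPy sketch below imitates that cast with `astype(np.int64)` on the exact values from the log; it is an illustration, not MONAI code.

```python
import numpy as np

# Unique label values printed in the log "before post process".
values = np.array([0.0, 0.9922, 0.9961, 1.0], dtype=np.float32)

# Casting float -> int truncates toward zero, like torch's .long().
indices = values.astype(np.int64)
print(indices)  # [0 0 0 1]

# Build 4 one-hot channels from those indices.
onehot = (indices[np.newaxis, :] == np.arange(4)[:, np.newaxis]).astype(np.float32)
for c in range(4):
    # Channels 2 and 3 come out all zeros, matching the log.
    print(f"channel {c}: {np.unique(onehot[c])}")
```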
dongyang0122 commented 2 years ago

@ChenY2000, if the task is segmentation, why are there floating-point numbers in the labels? Labels are normally integer values.
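One way to act on this advice is to validate labels before one-hot encoding: check that every value is a whole number in the range 0 to num_classes - 1, and only then cast to an integer dtype. The helper `as_integer_labels` below is a hypothetical NumPy sketch; it raises an error instead of silently truncating values such as 0.9922.

```python
import numpy as np


def as_integer_labels(label: np.ndarray, num_classes: int) -> np.ndarray:
    """Hypothetical check: return `label` as int64, or raise if it is not a valid class map."""
    fractional = np.mod(label, 1) != 0
    if np.any(fractional):
        raise ValueError(f"non-integer label values found: {np.unique(label[fractional])}")
    label = label.astype(np.int64)
    if label.min() < 0 or label.max() >= num_classes:
        raise ValueError(f"labels must lie in [0, {num_classes - 1}], got {np.unique(label)}")
    return label


print(as_integer_labels(np.array([0.0, 1.0, 2.0, 3.0]), num_classes=4))  # [0 1 2 3]

try:
    # The values from the log above are rejected instead of being truncated.
    as_integer_labels(np.array([0.0, 0.9922, 0.9961, 1.0]), num_classes=4)
except ValueError as err:
    print("rejected:", err)
```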

ying2611 commented 2 years ago

@dongyang0122, I adapted the 2d_segmentation tutorial for 2D multi-class segmentation by changing num_seg_classes to 3 when creating the array dataset. The labels coming out of the data loader were floating-point numbers by default.

Should I convert the labels to integers before transforming them?

Sorry for so many questions.

Many thanks


```python
import logging
import os
import sys
import tempfile
from glob import glob

import torch
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import ArrayDataset, DataLoader, create_test_image_2d, decollate_batch
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    AddChannel,
    AsDiscrete,
    Compose,
    EnsureType,
    LoadImage,
    RandRotate90,
    RandSpatialCrop,
    ScaleIntensity,
)


def main(tempdir):
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # create a temporary directory and 40 random image, mask pairs
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(40):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=3)
        Image.fromarray((im * 255).astype("uint8")).save(os.path.join(tempdir, f"img{i:d}.png"))
        Image.fromarray((seg * 255).astype("uint8")).save(os.path.join(tempdir, f"seg{i:d}.png"))

    images = sorted(glob(os.path.join(tempdir, "img*.png")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.png")))

    # define transforms for image and segmentation
    train_imtrans = Compose(
        [
            LoadImage(image_only=True),
            AddChannel(),
            ScaleIntensity(),
            RandSpatialCrop((96, 96), random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
            EnsureType(),
        ]
    )
    train_segtrans = Compose(
        [
            LoadImage(image_only=True),
            AddChannel(),
            ScaleIntensity(),
            RandSpatialCrop((96, 96), random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
            EnsureType(),
        ]
    )
    val_imtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])
    val_segtrans = Compose([LoadImage(image_only=True), AddChannel(), ScaleIntensity(), EnsureType()])

    # define array dataset, data loader
    check_ds = ArrayDataset(images, train_imtrans, segs, train_segtrans)
    check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
    im, seg = monai.utils.misc.first(check_loader)
    print(im.shape, seg.shape)

    # create a training data loader
    train_ds = ArrayDataset(images[:20], train_imtrans, segs[:20], train_segtrans)
    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available())
    # create a validation data loader
    val_ds = ArrayDataset(images[-20:], val_imtrans, segs[-20:], val_segtrans)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available())
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_label = AsDiscrete(to_onehot=4)
    post_pred = AsDiscrete(argmax=True, to_onehot=4)
    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=4,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
    ).to(device)
    loss_function = monai.losses.DiceLoss(to_onehot_y=True, softmax=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    for epoch in range(10):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{10}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            print(torch.unique(labels))
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                    roi_size = (96, 96)
                    sw_batch_size = 4
                    # print('unique labels is', torch.unique(val_labels), 'before post process')
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]

                    val_labels = [post_label(i) for i in decollate_batch(val_labels)]

                    # print('unique labels is', torch.unique(val_labels[0][0]), 'after post process')

                    # print('Values for first dimension after one-hot encode is', torch.unique(val_labels[0][0]))
                    # print('Values for second dimension after one-hot encode is', torch.unique(val_labels[0][1]))
                    # print('Values for third dimension after one-hot encode is', torch.unique(val_labels[0][2]))
                    # print('Values for fourth dimension after one-hot encode is', torch.unique(val_labels[0][3]))

                    # compute metric for current iteration
                    dice_metric(y_pred=val_outputs, y=val_labels)
                # aggregate the final mean dice result
                metric = dice_metric.aggregate().item()
                # reset the status for next validation round
                dice_metric.reset()
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), "best_metric_model_segmentation2d_array.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tempdir:
        main(tempdir)
```
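A likely source of the 0.9922 and 0.9961 values, offered as a diagnosis rather than something confirmed in this thread: `create_test_image_2d(..., num_seg_classes=3)` appears to return integer labels 0 to 3, and `(seg * 255).astype("uint8")` wraps around for products above 255 (2 * 255 = 510 wraps to 254, 3 * 255 = 765 wraps to 253). `ScaleIntensity()` in the segmentation transforms then rescales 0/253/254/255 to roughly 0/0.9922/0.9961/1.0, which is exactly what the log shows. The NumPy sketch reproduces both steps; the min-max rescale is a simplified stand-in for `ScaleIntensity` with its default output range [0, 1], not MONAI's implementation.

```python
import numpy as np

# Integer label map with classes 0 (background), 1, 2 and 3.
seg = np.array([[0, 1], [2, 3]], dtype=np.int32)

# What the script saves: seg * 255 cast to uint8, which wraps modulo 256.
saved = (seg * 255).astype("uint8")
print(saved.tolist())  # [[0, 255], [254, 253]]

# Simplified stand-in for ScaleIntensity(): min-max rescale to [0, 1].
as_float = saved.astype(np.float64)
scaled = (as_float - as_float.min()) / (as_float.max() - as_float.min())
print(np.round(np.unique(scaled), 4).tolist())  # [0.0, 0.9922, 0.9961, 1.0]
```

Note that the wrap-around also reverses the order of classes 2 and 3, so even a perfect rescale back to integers would not recover the original class ids from these PNGs.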
ying2611 commented 2 years ago

@dongyang0122, thank you so much for your help and the hint. I finally found the right way to do it.

Best wishes

NastaranVB commented 6 months ago

@ChenY2000, I have also run into this issue. Could you please tell me how you solved it?
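The thread does not record the exact fix, so the following is only one plausible approach, assuming the labels are corrupted when they are saved as `seg * 255` and then rescaled: write the integer label map to disk unscaled, and keep intensity scaling out of the segmentation transforms so the values stay at 0, 1, 2 and 3. In the tutorial script that would mean saving `seg.astype("uint8")` and dropping `ScaleIntensity()` from `train_segtrans` and `val_segtrans`. The NumPy sketch below only simulates the round trip (no file I/O); the toy label map and `NUM_CLASSES` are illustrative assumptions.

```python
import numpy as np

NUM_CLASSES = 4  # background + 3 foreground classes, matching out_channels=4

# Toy integer label map standing in for create_test_image_2d(..., num_seg_classes=3).
seg = np.array([[0, 1], [2, 3]], dtype=np.int32)

# Save the labels unscaled: values 0..3 fit in uint8 without wrapping.
saved = seg.astype("uint8")

# Loaded labels may be floats; with no intensity scaling they stay 0.0..3.0.
loaded = saved.astype(np.float32)[np.newaxis, ...]  # add channel axis -> (1, H, W)

# One-hot encode after a truncating integer cast, as the one-hot step would.
index = loaded[0].astype(np.int64)
onehot = (index[np.newaxis, ...] == np.arange(NUM_CLASSES).reshape(-1, 1, 1)).astype(np.float32)

for c in range(NUM_CLASSES):
    print(f"channel {c}: {np.unique(onehot[c])}")  # each channel now contains a 1
```

The same labels also reach `DiceLoss(to_onehot_y=True)` during training, which one-hot encodes its targets in the same way, so the change would affect the loss as well as the validation metric.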