Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

Please use MONAI Discussion tab for questions #6987

Closed: BelieferQAQ closed this issue 1 year ago.

BelieferQAQ commented 1 year ago

Please use MONAI's Discussions tab. For questions relating to MONAI usage, please do not create an issue.

Instead, use MONAI's GitHub Discussions tab. This can be found next to Issues and Pull Requests along the top of our repository.

Hello, I'm trying to follow the Swin UNETR tutorial on the BTCV dataset to segment lymph nodes. The lymph nodes are relatively small, and during training I get a warning about sample imbalance. In addition, the Dice score on the validation set is consistently low, often close to 0.

[screenshot]

The code is essentially the same as the tutorial's. The transforms are:

```python
# Imports for the full script
import os

import torch
from monai.data import decollate_batch
from monai.inferers import sliding_window_inference
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks.nets import SwinUNETR
from monai.transforms import (
    AsDiscrete,
    Compose,
    CropForegroundd,
    EnsureChannelFirstd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    Spacingd,
)

train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"], a_min=-250, a_max=500, b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=2,
            image_key="image",
            image_threshold=0,
        ),
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"], a_min=-250, a_max=500, b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ]
)
```
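For context on the imbalance warning: per the MONAI documentation, `RandCropByPosNegLabeld` uses `pos / (pos + neg)` as the probability of centering a crop on a foreground voxel, so `pos=1, neg=1` gives 0.5. The sketch below is a simplified, hypothetical re-implementation of that rule (not MONAI's code; `pick_crop_centers`, `fg_indices`, `bg_indices` and `seed` are illustrative names). Since the screenshot is not preserved, it assumes the warning is the kind raised when a label has no foreground (or no background) voxels, in which case balanced sampling is impossible and every crop falls back to the one class that exists.

```python
import random
import warnings


def pick_crop_centers(fg_indices, bg_indices, pos=1.0, neg=1.0, num_samples=2, seed=0):
    """Simplified sketch of pos/neg crop-center sampling (not MONAI's code).

    Each crop center is a foreground voxel with probability pos / (pos + neg),
    otherwise a background voxel.
    """
    rng = random.Random(seed)
    if not fg_indices and not bg_indices:
        raise ValueError("No sampling location available.")
    pos_ratio = pos / (pos + neg)
    if not fg_indices or not bg_indices:
        # One class is missing, so class-balanced sampling is impossible:
        # fall back to the class that exists and warn about it.
        pos_ratio = 0.0 if not fg_indices else 1.0
        warnings.warn(
            f"{len(fg_indices)} foreground / {len(bg_indices)} background voxels: "
            f"cannot sample class-balanced crops, using pos_ratio={pos_ratio}"
        )
    centers = []
    for _ in range(num_samples):
        # rng.random() is in [0, 1), so pos_ratio=1.0 always picks foreground
        pool = fg_indices if rng.random() < pos_ratio else bg_indices
        centers.append(rng.choice(pool))
    return centers, pos_ratio


# Settings from the issue: pos=1, neg=1, so half of the crops (on average)
# are centered on a lymph node voxel
centers, ratio = pick_crop_centers(fg_indices=[5, 6], bg_indices=[0, 1, 2, 3, 4, 7, 8, 9])
print(ratio)  # 0.5
```

Raising `pos` relative to `neg` shifts more crops onto the small foreground structures; a label with no foreground at all cannot be helped by the ratio.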

```python
device = torch.device("cuda:2")
model = SwinUNETR(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=2,
    feature_size=48,
    use_checkpoint=True,
).to(device)

# Load the self-supervised pretrained Swin ViT backbone weights
weight = torch.load("/data/jupyter/wd/swinUnetr-lymphNode/pretrain_models/model_swinvit.pt")
model.load_from(weights=weight)
print("Using pretrained self-supervised Swin UNETR backbone weights!")
```

```python
root_dir = "/data/jupyter/wd/swinUnetr-lymphNode/save_models_swin"
max_epochs = 500
val_interval = 5
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])

loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)
dice_metric = DiceMetric(include_background=False, reduction="mean")
```
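With `include_background=False`, the reported Dice covers only the lymph node channel, i.e. 2|P ∩ G| / (|P| + |G|) for predicted foreground P and ground truth G. The pure-Python sketch below uses hypothetical helpers (`argmax_labels`, `foreground_dice`), not MONAI's `DiceMetric`: it applies the argmax step of `post_pred` and computes that score, showing that a model predicting background everywhere scores exactly 0 on any case whose label is non-empty. Returning NaN for an empty label is an assumption modelled on `DiceMetric`'s default `ignore_empty=True`.

```python
import math


def argmax_labels(scores):
    """Per-voxel argmax over class scores, like the argmax step of post_pred.

    `scores` is a flat list of [background, foreground] score pairs.
    """
    return [max(range(len(s)), key=lambda c: s[c]) for s in scores]


def foreground_dice(pred_labels, true_labels, fg_class=1):
    """Dice for a single foreground class: 2|P & G| / (|P| + |G|)."""
    pred = [p == fg_class for p in pred_labels]
    true = [t == fg_class for t in true_labels]
    if not any(true):
        # Assumption: an empty ground truth is treated as undefined (NaN)
        return math.nan
    inter = sum(p and t for p, t in zip(pred, true))
    return 2.0 * inter / (sum(pred) + sum(true))


# A tiny 6-voxel "volume" whose ground truth has 2 lymph node voxels
truth = [0, 0, 1, 1, 0, 0]
all_background = argmax_labels([[0.9, 0.1]] * 6)
partial = argmax_labels(
    [[0.9, 0.1], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4], [0.9, 0.1], [0.9, 0.1]]
)
print(foreground_dice(all_background, truth))  # 0.0
print(foreground_dice(partial, truth))  # 0.6666666666666666
```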

```python
# train_ds, train_loader and val_loader are built as in the tutorial (not shown)
for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(f"{step}/{len(train_ds) // train_loader.batch_size}, train_loss: {loss.item():.4f}")
    # Cosine schedule with T_max=max_epochs, stepped once per epoch
    scheduler.step()
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                roi_size = (96, 96, 96)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)

            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()

            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_model_s1new_s5_val.pth"))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}"
            )
```
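During validation, `sliding_window_inference` covers each whole resampled volume with 96×96×96 windows and runs `sw_batch_size=4` of them through the model at a time. The sketch below uses a hypothetical helper (`window_starts`), not MONAI's implementation, to show one way the window start offsets along a single axis can be laid out, assuming MONAI's default `overlap=0.25`: the stride is 72 voxels for a 96-voxel window, and the last window is shifted back so it ends at the border. The number of windows per volume is the product over the three axes.

```python
import math


def window_starts(size, roi, overlap=0.25):
    """Start offsets of sliding windows along one axis (simplified sketch).

    Windows advance by int(roi * (1 - overlap)) voxels; the last window is
    shifted back so that it ends exactly at the volume border.
    """
    if size <= roi:
        # A single window; a smaller volume would be padded up to roi
        return [0]
    step = max(int(roi * (1 - overlap)), 1)
    starts = list(range(0, size - roi, step))
    starts.append(size - roi)
    return starts


# Hypothetical volume of 200 x 168 x 96 voxels after resampling
axes = [window_starts(n, 96) for n in (200, 168, 96)]
print(axes)  # [[0, 72, 104], [0, 72], [0]]
print(math.prod(len(a) for a in axes))  # 6
```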
KumoLiu commented 1 year ago

Hi @BelieferQAQ, I think you should check your data after `train_transforms`. In case you haven't found it yet, here is the tutorial: https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/swin_unetr_btcv_segmentation_3d.ipynb

Thanks!
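Following the suggestion to check the data after `train_transforms`, a minimal sketch: the hypothetical helper `check_samples` reports the image and label shapes and the number of foreground (label > 0) voxels per sample. It relies only on `.shape` and `(label > 0).sum()`, which NumPy arrays and torch tensors both provide. A crop with 0 foreground voxels is expected for negative samples; a whole case with 0 foreground voxels would point to a label problem.

```python
def check_samples(samples, image_key="image", label_key="label"):
    """Summarize transformed samples: shapes and foreground voxel counts.

    `samples` is a list of dicts of array-like data (NumPy arrays or torch
    tensors); with num_samples=2, train_transforms returns such a list per case.
    """
    report = []
    for i, sample in enumerate(samples):
        image, label = sample[image_key], sample[label_key]
        n_fg = int((label > 0).sum())  # voxels labelled as lymph node
        report.append(
            {
                "sample": i,
                "image_shape": tuple(image.shape),
                "label_shape": tuple(label.shape),
                "foreground_voxels": n_fg,
            }
        )
    return report


# Usage with MONAI (assumes `train_files` is your list of
# {"image": path, "label": path} dicts; the name is hypothetical):
#   for row in check_samples(train_transforms(train_files[0])):
#       print(row)
#   # val_transforms has no random crop, so this shows the whole resampled case
#   print(check_samples([val_transforms(train_files[0])]))
```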