Closed BelieferQAQ closed 1 year ago
Hi @BelieferQAQ, I think you should check your data after `train_transforms`.
In case you didn't find it, here is the tutorial:
https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/swin_unetr_btcv_segmentation_3d.ipynb
Thanks!
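To make the suggested check concrete: one way to "check your data after train_transforms" is to summarize a transformed sample's shape, intensity range, and label content, and confirm the cropped patch actually contains foreground voxels. This is a minimal, MONAI-free sketch; the synthetic `patch` dict merely stands in for one output of `train_transforms`, and `summarize_sample` is a hypothetical helper, not part of MONAI:

```python
import numpy as np

def summarize_sample(sample):
    """Report shape, intensity range, and label content of one transformed sample.

    `sample` is assumed to be a dict with channel-first "image" and "label"
    arrays, as produced by dictionary transforms.
    """
    image = np.asarray(sample["image"])
    label = np.asarray(sample["label"])
    fg = int((label > 0).sum())
    return {
        "image_shape": image.shape,
        "intensity_range": (float(image.min()), float(image.max())),
        "label_values": sorted(np.unique(label).tolist()),
        "foreground_voxels": fg,
        "foreground_fraction": fg / label.size,
    }

# Synthetic patch standing in for one output of train_transforms:
rng = np.random.default_rng(0)
patch = {
    "image": rng.random((1, 96, 96, 96), dtype=np.float32),
    "label": np.zeros((1, 96, 96, 96), dtype=np.int64),
}
patch["label"][0, 40:44, 40:44, 40:44] = 1  # a small "lymph node"
print(summarize_sample(patch))
```

If `label_values` is only `[0]` for many patches, or the intensity range is not within `[0, 1]` after `ScaleIntensityRanged`, the training data itself is the first thing to fix.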
Hello, I'm trying to follow a tutorial on segmenting lymph nodes using the BTCV dataset with the Swin UNETR network. The lymph nodes are relatively small, and during training I'm getting a warning about sample imbalance. Additionally, the Dice score on the validation set is consistently low, often close to 0.
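Both symptoms are consistent with a model that collapses to predicting background everywhere: with a tiny foreground structure, such a prediction has very low loss gradient signal yet a Dice of exactly 0 on the foreground class. A minimal numpy illustration (mask sizes are made up, and this `dice` function is a plain sketch, not MONAI's `DiceMetric`):

```python
import numpy as np

def dice(pred, target):
    """Dice overlap of two binary masks; 0.0 when they don't intersect."""
    inter = np.logical_and(pred, target).sum()
    denom = pred.sum() + target.sum()
    return 2.0 * inter / denom if denom > 0 else 1.0

label = np.zeros((96, 96, 96), dtype=bool)
label[40:44, 40:44, 40:44] = True      # tiny foreground structure

all_background = np.zeros_like(label)  # a collapsed model's prediction
print(dice(all_background, label))     # 0.0 -- matches the near-zero metric
print(label.mean())                    # foreground fraction, ~7e-5 here
```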
The entire code is essentially the same as the tutorial, and the transforms used are as follows:

```python
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-250,
            a_max=500,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=2,
            image_key="image",
            image_threshold=0,
        ),
    ]
)

val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-250,
            a_max=500,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ]
)
```
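One knob relevant to the imbalance warning: per MONAI's documentation, `RandCropByPosNegLabeld` centers a crop on a foreground voxel with probability `pos / (pos + neg)`, so `pos=1, neg=1` gives only a 50% chance of a foreground-centered patch. Raising `pos` oversamples the small lymph nodes. The ratio itself can be simulated with a toy Bernoulli draw (the sampling logic below is my own sketch, not MONAI's code):

```python
import numpy as np

def crop_center_is_foreground(pos, neg, rng):
    """Bernoulli draw with MONAI's documented ratio pos / (pos + neg)."""
    return rng.random() < pos / (pos + neg)

rng = np.random.default_rng(0)

draws = [crop_center_is_foreground(pos=1, neg=1, rng=rng) for _ in range(10_000)]
print(sum(draws) / len(draws))   # ~0.5 foreground-centred crops

draws = [crop_center_is_foreground(pos=3, neg=1, rng=rng) for _ in range(10_000)]
print(sum(draws) / len(draws))   # ~0.75 when foreground is oversampled
```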
```python
device = torch.device("cuda:2")
model = SwinUNETR(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=2,
    feature_size=48,
    use_checkpoint=True,
).to(device)

weight = torch.load("/data/jupyter/wd/swinUnetr-lymphNode/pretrain_models/model_swinvit.pt")
model.load_from(weights=weight)
print("Using pretrained self-supervised Swin UNETR backbone weights!")
```
```python
root_dir = "/data/jupyter/wd/swinUnetr-lymphNode/save_models_swin"
max_epochs = 500
val_interval = 5
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])

loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)
dice_metric = DiceMetric(include_background=False, reduction="mean")
```
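For context on the post-processing above: `DiceMetric` expects one-hot inputs, and `post_pred = AsDiscrete(argmax=True, to_onehot=2)` turns one channel-first logit volume into exactly that. My reading of that step, sketched in plain numpy (this is an illustration, not MONAI's implementation):

```python
import numpy as np

def argmax_onehot(logits, num_classes=2):
    """Per-voxel argmax over channels, then one-hot encode, channel-first."""
    labels = logits.argmax(axis=0)                        # per-voxel class index
    onehot = np.eye(num_classes, dtype=np.float32)[labels]
    return np.moveaxis(onehot, -1, 0)                     # (num_classes, ...)

# Tiny (C=2, 2, 2) logit volume for illustration:
logits = np.array([[[2.0, -1.0],
                    [0.5,  0.0]],
                   [[0.0,  1.0],
                    [1.5, -0.5]]])
pred = argmax_onehot(logits)
print(pred[1])   # foreground channel: 1 where channel 1 won the argmax
```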
```python
for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(f"{step}/{len(train_ds) // train_loader.batch_size}, " f"train_loss: {loss.item():.4f}")
    scheduler.step()
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
```