Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

labels and image not overlapping / error runtime in torch.cat #6863

Closed. Striking-Project closed this issue 1 year ago.

Striking-Project commented 1 year ago

Discussed in https://github.com/Project-MONAI/MONAI/discussions/6861

Originally posted by **Striking-Project** August 13, 2023

Hi @KumoLiu, these are the transforms I am using on the data I collected:

```python
import os

import torch
from monai.transforms import Compose, EnsureTyped, LoadImaged, RandCropByPosNegLabeld

num_samples = 2
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"], ensure_channel_first=True),
        EnsureTyped(keys=["image", "label"], device=device, track_meta=False),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=num_samples,
            image_key="image",
            image_threshold=0,
            # allow_smaller=True: an axis shorter than 96 is not cropped,
            # so a crop can be smaller than 96 x 96 x 96
            allow_smaller=True,
        ),
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"], ensure_channel_first=True),
        EnsureTyped(keys=["image", "label"], device=device, track_meta=True),
    ]
)
```

I'm trying to train SwinUNETR following the MONAI BTCV segmentation tutorial. This is the code for plotting the data:

```python
import os

import matplotlib.pyplot as plt

# slice index to display for each validation volume
slice_map = {
    "ci1clear_PELVIS_20210104171404_11.nii.gz": 19,
    "ci2clear_PELVIS_20210104161217_6.nii.gz": 13,
    "ci3clear_PELVIS_20210106090948_6.nii.gz": 15,
    "ci4clear_PELVIS_ABIR_20210107182848_11.nii.gz": 17,
    "ci5clear_PELVIS_20210108135040_9.nii.gz": 21,
    "ci6clear_PELVIS_ABIR_20210109094218_7.nii.gz": 29,
    "ci7clear_PELVIS_ABIR_20210111122138_8.nii.gz": 21,
    "c8clear_PELVIS_ABIR_20210111132001_8.nii.gz": 120,
    "c9_PELVIS_ABIR_20210115141302_5.nii.gz": 120,
    "c10clear_PELVIS_20210118143626_6.nii.gz": 15,
}
case_num = 3
# val_ds: the validation dataset built with val_transforms, as in the tutorial
img_name = os.path.split(val_ds[case_num]["image"].meta["filename_or_obj"])[1]
img = val_ds[case_num]["image"]
label = val_ds[case_num]["label"]
img_shape = img.shape
label_shape = label.shape
print(f"image shape: {img_shape}, label shape: {label_shape}")

# show the same slice of the image and the label side by side
plt.figure("image", (18, 6))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(img[0, :, :, slice_map[img_name]].detach().cpu(), cmap="gray")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[0, :, :, slice_map[img_name]].detach().cpu())
plt.show()
```

The label has to be rotated by 180 degrees to overlap the image. I get the same result when I convert the NIfTI segmentation file to PNG, but it shouldn't be that way.

![image](https://github.com/Project-MONAI/MONAI/assets/80620671/c040dffc-5e4a-4905-a55e-72d6a8be194d)

I think it may also be causing another major problem: the dimensions don't match in `torch.cat` inside the decoder block.

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
in ()
     10 metric_values = []
     11 while global_step < max_iterations:
---> 12     global_step, dice_val_best, global_step_best = train(global_step, train_loader, dice_val_best, global_step_best)
     13 model.load_state_dict(torch.load(os.path.join(root_dir, "best_metric_model.pth")))

6 frames
in train(global_step, train_loader, dice_val_best, global_step_best)
     41         x, y = (batch["image"].cuda(), batch["label"].cuda())
     42         with torch.cuda.amp.autocast():
---> 43             logit_map = model(x)
     44             loss = loss_function(logit_map, y)
     45         scaler.scale(loss).backward()

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

in forward(self, x_in)
    306         enc3 = self.encoder4(hidden_states_out[2])
    307         dec4 = self.encoder10(hidden_states_out[4])
--> 308         dec3 = self.decoder5(dec4, hidden_states_out[3])
    309         dec2 = self.decoder4(dec3, enc3)
    310         dec1 = self.decoder3(dec2, enc2)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.10/dist-packages/monai/networks/blocks/unetr_block.py in forward(self, inp, skip)
     82         # number of channels for skip should equals to out_channels
     83         out = self.transp_conv(inp)
---> 84 
     85         if out.shape[1] != skip.shape[1]:
     86             skip = torch.nn.functional.interpolate(skip, size=out.shape[2:], mode='nearest')

/usr/local/lib/python3.10/dist-packages/monai/data/meta_tensor.py in __torch_function__(cls, func, types, args, kwargs)
    280         if kwargs is None:
    281             kwargs = {}
--> 282         ret = super().__torch_function__(func, types, args, kwargs)
    283         # if `out` has been used as argument, metadata is not copied, nothing to do.
    284         # if "out" in kwargs:

/usr/local/lib/python3.10/dist-packages/torch/_tensor.py in __torch_function__(cls, func, types, args, kwargs)
   1293 
   1294         with _C.DisableTorchFunctionSubclass():
-> 1295             ret = func(*args, **kwargs)
   1296         if func in get_default_nowrap_functions():
   1297             return ret

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 4 but got size 3 for tensor number 1 in the list.
```
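One possible explanation for the `torch.cat` error, offered as a sketch rather than a confirmed diagnosis. Judging by the traceback, SwinUNETR's decoder upsamples the bottleneck features by 2 and concatenates them with the skip connection one level up (`decoder5(dec4, hidden_states_out[3])`). The helpers below (`encoder_sizes`, `first_mismatch`, `padded_size` are hypothetical names) assume the encoder halves each spatial dimension five times and pads odd sizes up; that is an assumption about the network's internals. Under it, any crop dimension that is not a multiple of 32 (= 2**5) can leave the upsampled tensor one voxel larger than its skip, and every size from 33 to 48 reproduces exactly "Expected size 4 but got size 3". With `allow_smaller=True`, `RandCropByPosNegLabeld` keeps the full extent of any axis shorter than 96, so a volume with few slices could yield such a crop. Note also that the check in the modified `unetr_block.py` compares `shape[1]` (channels), while the error concerns a spatial dimension.

```python
import math
from typing import List, Optional, Tuple

# Assumption: the encoder halves each spatial dimension five times
# and pads odd sizes up, i.e. each stage computes ceil(size / 2).
NUM_STAGES = 5


def encoder_sizes(size: int, num_stages: int = NUM_STAGES) -> List[int]:
    """Input size followed by the size after each 2x downsampling (ceil)."""
    sizes = [size]
    for _ in range(num_stages):
        sizes.append(math.ceil(sizes[-1] / 2))
    return sizes


def first_mismatch(size: int) -> Optional[Tuple[int, int]]:
    """(upsampled, skip) sizes at the deepest decoder step that cannot be concatenated.

    Walks from the bottleneck upward: each step doubles the deeper feature map
    and concatenates it with the next shallower one. Returns None if all match.
    """
    sizes = encoder_sizes(size)
    for level in range(len(sizes) - 1, 0, -1):
        upsampled, skip = 2 * sizes[level], sizes[level - 1]
        if upsampled != skip:
            return upsampled, skip
    return None


def padded_size(size: int, multiple: int = 2 ** NUM_STAGES) -> int:
    """Smallest multiple of `multiple` (32 by default) that is >= size."""
    return math.ceil(size / multiple) * multiple


for crop in (96, 64, 48, 40):
    print(crop, encoder_sizes(crop), first_mismatch(crop), padded_size(crop))
```

In MONAI terms, one way to keep every crop a multiple of 32 is to pad before cropping, for example `SpatialPadd(keys=["image", "label"], spatial_size=(96, 96, 96))` placed ahead of `RandCropByPosNegLabeld`, or `DivisiblePadd(keys=["image", "label"], k=32)`. Whether heavy padding suits volumes with so few slices is a judgment call for this dataset.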
KumoLiu commented 1 year ago

Duplicate of https://github.com/Project-MONAI/MONAI/discussions/6861, closing for now.
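On the first symptom in the report (the label only lines up after a 180-degree rotation): flipping both in-plane axes of a slice is the same as rotating it by 180 degrees, so one common cause is an image and a label whose headers disagree in the direction of the first two voxel axes (for example LPS versus RAS), with the raw arrays plotted without any reorientation. This is a hedged guess, not a confirmed diagnosis. The sketch below is a simplified, numpy-only stand-in for nibabel's `aff2axcodes` that reads orientation codes off a 4x4 affine so the two headers can be compared; `axcodes`, `flipped_axes` and the example affines are hypothetical, not taken from this dataset.

```python
import numpy as np

# (negative, positive) direction labels for world axes x, y, z (RAS+ convention)
AXIS_LABELS = (("L", "R"), ("P", "A"), ("I", "S"))
OPPOSITE = {"L": "R", "R": "L", "P": "A", "A": "P", "I": "S", "S": "I"}


def axcodes(affine) -> str:
    """Orientation code such as 'RAS' for a 4x4 voxel-to-world affine.

    For each voxel axis (column of the 3x3 block) take the world axis it points
    along most strongly and read the direction from the sign.
    Simplified: assumes the affine is not strongly oblique.
    """
    rzs = np.asarray(affine, dtype=float)[:3, :3]
    codes = []
    for col in range(3):
        world_axis = int(np.argmax(np.abs(rzs[:, col])))
        positive = rzs[world_axis, col] > 0
        codes.append(AXIS_LABELS[world_axis][int(positive)])
    return "".join(codes)


def flipped_axes(code_a: str, code_b: str) -> list:
    """Voxel axes that point in opposite directions in the two orientation codes."""
    return [i for i, (a, b) in enumerate(zip(code_a, code_b)) if OPPOSITE[a] == b]


# Hypothetical headers: 0.8 mm in-plane spacing, 3 mm slices.
image_affine = np.diag([-0.8, -0.8, 3.0, 1.0])
label_affine = np.diag([0.8, 0.8, 3.0, 1.0])
print(axcodes(image_affine), axcodes(label_affine))  # LPS RAS
print(flipped_axes(axcodes(image_affine), axcodes(label_affine)))  # [0, 1]

# Flipping both in-plane axes is the same as rotating the slice by 180 degrees.
slice_2d = np.arange(6).reshape(2, 3)
assert np.array_equal(slice_2d[::-1, ::-1], np.rot90(slice_2d, 2))
```

With real files, the affines could come from `nibabel.load(path).affine` or, with `track_meta=True` as in the validation transforms, from the loaded `MetaTensor`'s `.affine`. If the codes differ and both headers are correct, adding `Orientationd(keys=["image", "label"], axcodes="RAS")` after `LoadImaged` should put both volumes in the same orientation. If the label header itself is wrong, reorienting cannot help and the label would need to be re-exported with the image's affine.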