Project-MONAI / tutorials

MONAI Tutorials
https://monai.io/started.html
Apache License 2.0

Cannot use decollate_batch on torch tensors, and cannot stack MetaTensors because I can't account for the different metadata of each sample #1620

Open 0tist opened 9 months ago

0tist commented 9 months ago

Bug Description: I'm trying to use a custom collate function for a DataLoader. The scans in each batch are fetched randomly from a collection gathered from different machines (e.g. MRI and CT scanners), so their metadata differs. I want to preserve the metadata for all of these tensors. Since I can't do that with MONAI's MetaTensor, I keep the metadata in a separate list indexed by sample and pass the collated torch tensors on.

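Code for the custom collate fn:

```python
import torch


def val_datalist_collate(self, batch):
    # Method of the trainer class: `batch` is a list of per-sample dicts
    # holding "image" and "label" tensors.
    imgs = []
    labels = []
    batch_output = {}
    for sample in batch:
        imgs.append(sample['image'])
        lbl = sample['label']
        if self.enable_binary_class:
            # Merge all foreground classes into one class (modifies lbl in place).
            lbl[lbl > 0] = 1
        labels.append(lbl)

    # Stack the per-sample tensors along a new leading batch dimension.
    batch_output['image'] = torch.stack(imgs)
    batch_output['label'] = torch.stack(labels)

    return batch_output
```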
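The description says the metadata is kept in a separate list while plain torch tensors are passed on, but the function above stacks the samples exactly as they arrive. Stacking MetaTensors with `torch.stack` appears to give a MetaTensor carrying a single sample's unbatched metadata, which is what `decollate_batch` later trips over. A minimal sketch of the separate-list approach, assuming each `image`/`label` is either a MONAI MetaTensor (exposing `.meta`) or a plain tensor; the helper name `collate_with_separate_meta`, its `binarize_label` flag and the `"meta"` output key are hypothetical:

```python
import torch


def collate_with_separate_meta(batch, binarize_label=False):
    """Stack images and labels as plain tensors; keep per-sample metadata in a list.

    Hypothetical helper: `batch` is a list of dicts with "image" and "label"
    tensors, which may be MONAI MetaTensors (carrying `.meta`) or plain tensors.
    """
    images, labels, metas = [], [], []
    for sample in batch:
        img, lbl = sample["image"], sample["label"]
        # Copy each sample's metadata (empty dict for plain tensors); list index == batch index.
        metas.append({
            "image": dict(getattr(img, "meta", {})),
            "label": dict(getattr(lbl, "meta", {})),
        })
        # Drop any tensor subclass so torch.stack returns a plain torch.Tensor
        # rather than a MetaTensor that may carry only one sample's metadata.
        img = img.as_subclass(torch.Tensor)
        lbl = lbl.as_subclass(torch.Tensor)
        if binarize_label:
            # Build a new tensor instead of writing into the sample in place.
            lbl = (lbl > 0).to(lbl.dtype)
        images.append(img)
        labels.append(lbl)
    return {
        "image": torch.stack(images),
        "label": torch.stack(labels),
        "meta": metas,
    }
```

Because the stacked batch is an ordinary `torch.Tensor`, `decollate_batch` simply unbinds it along dim 0 and has no `.meta` to split; the metadata for sample `i` is looked up as `out["meta"][i]`. Unlike the original, the binarization here does not modify the dataset's sample, which matters if the samples are cached.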

To Reproduce: Build a DataLoader that uses the custom collate function above, then call `decollate_batch` on a batch loaded from it. `decollate_batch` is imported from `monai.data` (`from monai.data import decollate_batch`).

Error log

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[7], line 1
----> 1 mm.train()

File ~/opet/opet/src/main.py:210, in ModelMaker.train(self)
    207             mlflow.log_param(f'{k}/{k_}', v_)
    209 ##### RUN TRAINING ######
--> 210 self.trainer.train(self.exp_id, run_id)

File ~/opet/opet/src/segmentation_3D/core/train.py:70, in trainer.train(self, exp_id, run_id)
     68 patience -= 1
     69 if not(patience):
---> 70     mean_dice_val, val_loss = self.validate(val_loader)
     71     state_dict = {'exp_id': exp_id,
     72                   'run_id': run_id,
     73                   'optimizer': self.optimizer.state_dict(),
     74                   'training_loss': epoch_tr_loss,
     75                   'val loss': val_loss,
     76                   'epoch': epoch+1}
     77     self.save_model(dice_val= mean_dice_val, **state_dict)

File ~/opet/opet/src/segmentation_3D/core/train.py:108, in trainer.validate(self, val_loader)
    106         x_val, y_val = val_sample['image'], val_sample['label']
    107         x_val, y_val = x_val.to(self.device), y_val.to(self.device)
--> 108         loss_val = self.val_one_iter(x_val, y_val)
    109         epoch_val_loss += loss_val
    111 dice_val = self.dice_metric.aggregate().item()

File ~/opet/opet/src/segmentation_3D/core/train.py:130, in trainer.val_one_iter(self, X, Y)
    127 post_pred = AsDiscrete(argmax=True, to_onehot=self.n_classes)
    129 val_outputs = sliding_window_inference(X, self.patch_size, self.num_samples, self.model)
--> 130 val_labels_list = decollate_batch(Y)
    131 val_labels_convert = [
    132     post_label(val_label_tensor) for val_label_tensor in val_labels_list
    133 ]
    134 val_outputs_list = decollate_batch(val_outputs)

File ~/miniconda3/envs/ailib/lib/python3.10/site-packages/monai/data/utils.py:619, in decollate_batch(batch, detach, pad, fill_value)
    617 # if of type MetaObj, decollate the metadata
    618 if isinstance(batch, MetaObj):
--> 619     for t, m in zip(out_list, decollate_batch(batch.meta)):
    620         if isinstance(t, MetaObj):
    621             t.meta = m

File ~/miniconda3/envs/ailib/lib/python3.10/site-packages/monai/data/utils.py:631, in decollate_batch(batch, detach, pad, fill_value)
    628         return [t.item() for t in out_list]
    629     return list(out_list)
--> 631 b, non_iterable, deco = _non_zipping_check(batch, detach, pad, fill_value)
    632 if b <= 0:  # all non-iterable, single item "batch"? {"image": 1, "label": 1}
    633     return deco

File ~/miniconda3/envs/ailib/lib/python3.10/site-packages/monai/data/utils.py:532, in _non_zipping_check(batch_data, detach, pad, fill_value)
    530 _deco: Mapping | Sequence
    531 if isinstance(batch_data, Mapping):
--> 532     _deco = {key: decollate_batch(batch_data[key], detach, pad=pad, fill_value=fill_value) for key in batch_data}
    533 elif isinstance(batch_data, Iterable):
    534     _deco = [decollate_batch(b, detach, pad=pad, fill_value=fill_value) for b in batch_data]

File ~/miniconda3/envs/ailib/lib/python3.10/site-packages/monai/data/utils.py:532, in <dictcomp>(.0)
    530 _deco: Mapping | Sequence
    531 if isinstance(batch_data, Mapping):
--> 532     _deco = {key: decollate_batch(batch_data[key], detach, pad=pad, fill_value=fill_value) for key in batch_data}
    533 elif isinstance(batch_data, Iterable):
    534     _deco = [decollate_batch(b, detach, pad=pad, fill_value=fill_value) for b in batch_data]

File ~/miniconda3/envs/ailib/lib/python3.10/site-packages/monai/data/utils.py:631, in decollate_batch(batch, detach, pad, fill_value)
    628         return [t.item() for t in out_list]
    629     return list(out_list)
--> 631 b, non_iterable, deco = _non_zipping_check(batch, detach, pad, fill_value)
    632 if b <= 0:  # all non-iterable, single item "batch"? {"image": 1, "label": 1}
    633     return deco

File ~/miniconda3/envs/ailib/lib/python3.10/site-packages/monai/data/utils.py:534, in _non_zipping_check(batch_data, detach, pad, fill_value)
    532     _deco = {key: decollate_batch(batch_data[key], detach, pad=pad, fill_value=fill_value) for key in batch_data}
    533 elif isinstance(batch_data, Iterable):
--> 534     _deco = [decollate_batch(b, detach, pad=pad, fill_value=fill_value) for b in batch_data]
    535 else:
    536     raise NotImplementedError(f"Unable to de-collate: {batch_data}, type: {type(batch_data)}.")

TypeError: iteration over a 0-d array
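The innermost frames show `decollate_batch` descending into `batch.meta` (a dict), then into one of its values, and treating that value as an `Iterable`. A NumPy 0-d array passes an `Iterable` check because `ndarray` defines `__iter__`, yet iterating it raises exactly this `TypeError`. The sketch below illustrates that with NumPy alone, assuming the offending metadata value is a 0-d `ndarray` (the traceback does not say which key it is); the helper `describe_iteration` and the value `7` are purely illustrative. One plausible source of such a value is a stacked MetaTensor whose metadata was never batched, so a scalar field was not turned into a per-sample array.

```python
from collections.abc import Iterable

import numpy as np


def describe_iteration(value):
    """Report whether `value` passes an Iterable check and what iterating it does."""
    looks_iterable = isinstance(value, Iterable)
    try:
        items = list(value)
        outcome = f"iterated {len(items)} item(s)"
    except TypeError as exc:
        outcome = f"TypeError: {exc}"
    return looks_iterable, outcome


# A 1-d array, as a properly batched metadata field would be: iterates fine.
print(describe_iteration(np.array([7, 7])))
# A 0-d array, as an unbatched scalar field might be stored:
# it passes the Iterable check but cannot be iterated.
print(describe_iteration(np.array(7)))
```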
KumoLiu commented 9 months ago

Hi @0tist, I think MONAI's `list_data_collate` can collate MetaTensors. https://github.com/Project-MONAI/MONAI/blob/facf17693410d41170edd8e94364b4f341369aea/monai/data/utils.py#L505