Project-MONAI / MONAILabel

MONAI Label is an intelligent open source image labeling and learning tool.
https://docs.monai.io/projects/label
Apache License 2.0

BUG: Fix traceback in restored transform #1766

Closed ThomasKierski closed 1 month ago

ThomasKierski commented 1 month ago

Transforms from MONAI and MONAILabel can leave torch.Tensor and torch.Size objects in the metadata. This change fixes a TypeError raised in the Restored transform, caused by an incompatibility between these torch data types and numpy.any().
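One way to sidestep this kind of incompatibility is to convert both sizes to plain Python integers before comparing them, so that no NumPy reduction ever sees a torch object. The sketch below is illustrative only and is not the patch in this PR: `as_int_tuple` and `sizes_differ` are hypothetical helper names. It relies on `torch.Tensor` and `numpy.ndarray` both exposing `tolist()`, and on `torch.Size` being iterable like a tuple.

```python
from typing import Any, Tuple


def as_int_tuple(size: Any) -> Tuple[int, ...]:
    """Convert a 1-D size-like object to a tuple of plain Python ints.

    Hypothetical helper. Objects exposing ``tolist()`` (torch.Tensor and
    numpy.ndarray both do) are converted through it; anything else is
    treated as an iterable, which covers torch.Size, tuples, and lists.
    """
    if hasattr(size, "tolist"):
        size = size.tolist()
    return tuple(int(v) for v in size)


def sizes_differ(current_size: Any, spatial_size: Any) -> bool:
    """Return True when the two sizes are not element-wise identical."""
    return as_int_tuple(current_size) != as_int_tuple(spatial_size)


if __name__ == "__main__":
    # Plain sequences stand in for torch.Size / torch.Tensor metadata here.
    print(sizes_differ([128, 128, 128], (512, 512, 90)))
    print(sizes_differ((128, 128, 128), [128, 128, 128]))
```

Comparing tuples of ints returns an ordinary `bool`, so the check no longer depends on how NumPy dispatches `any()` to non-ndarray inputs.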

Signed-off-by: Thomas Kierski <thomas.kierski@revvity.com>

ThomasKierski commented 1 month ago

A snippet to reproduce the transform pipeline that resulted in this traceback is below:

    from monai.apps.deepedit.transforms import NormalizeLabelsInDatasetd, SplitPredsLabeld
    from monai.transforms import (
        Activationsd,
        AsDiscreted,
        Compose,
        EnsureChannelFirstd,
        LoadImaged,
        Orientationd,
        Resized,
        ScaleIntensityRanged,
    )
    from monailabel.transform.post import Restored

    # label_names, num_labels, and model are defined elsewhere.

    pre_transforms = Compose([
        LoadImaged(keys=("image", "label"), reader="ITKReader"),
        EnsureChannelFirstd(keys=("image", "label")),
        NormalizeLabelsInDatasetd(keys="label", label_names=label_names),
        Orientationd(keys=("image", "label"), axcodes="RAS"),
        ScaleIntensityRanged(keys="image", a_min=-1000, a_max=3500, b_min=0.0, b_max=1.0, clip=True),
        Resized(keys=("image", "label"), spatial_size=3 * [128], mode=("area", "nearest")),
    ])

    post_transforms = Compose([
        Activationsd(keys="pred", softmax=True),
        AsDiscreted(
            keys=("pred", "label"),
            argmax=(True, False),
            to_onehot=num_labels,
        ),
        SplitPredsLabeld(keys="pred"),
    ])

    data = {"image": "/path/to/image.mha", "label": "/path/to/label.seg.nrrd"}
    data = pre_transforms(data)
    data["pred"] = model(data["image"])
    # Restored resizes the prediction back to the reference image's original size.
    restorex = Restored(keys=["pred"], ref_image="image")
    restored_prediction = restorex(data)

The traceback in question:

    TypeError                                 Traceback (most recent call last)
    Cell In[10], line 81
         79 data = monai.data.decollate_batch(data)
         80 for n in range(len(data)):
    ---> 81     data[n] = restorex(data[n])
         82 data = monai.data.list_data_collate(data)
         83 # write_sample_to_disk(data, os.path.join(outpath,os.path.basename(data["label"].meta["filename_or_obj"][0])))

    File /usr/local/lib/python3.10/dist-packages/monailabel/transform/post.py:132, in Restored.__call__(self, data)
        129 spatial_size = spatial_shape[-len(current_size) :]
        131 # Undo Spacing
    --> 132 if np.any(np.not_equal(current_size, spatial_size)):
        133     resizer = Resize(spatial_size=spatial_size, mode=self.mode[idx])
        134     result = resizer(result, mode=self.mode[idx], align_corners=self.align_corners[idx])

    File <__array_function__ internals>:200, in any(*args, **kwargs)

    File /usr/local/lib/python3.10/dist-packages/numpy/core/fromnumeric.py:2423, in any(a, axis, out, keepdims, where)
       2333 @array_function_dispatch(_any_dispatcher)
       2334 def any(a, axis=None, out=None, keepdims=np._NoValue, *, where=np._NoValue):
       2335     """
       2336     Test whether any array element along a given axis evaluates to True.
       2337
       (...)
       2421
       2422     """
    -> 2423     return _wrapreduction(a, np.logical_or, 'any', axis, None, out,
       2424                           keepdims=keepdims, where=where)

    File /usr/local/lib/python3.10/dist-packages/numpy/core/fromnumeric.py:84, in _wrapreduction(obj, ufunc, method, axis, dtype, out, **kwargs)
         82         return reduction(axis=axis, dtype=dtype, out=out, **passkwargs)
         83     else:
    ---> 84         return reduction(axis=axis, out=out, **passkwargs)
         86 return ufunc.reduce(obj, axis, dtype, out, **passkwargs)

    TypeError: any() received an invalid combination of arguments - got (out=NoneType, axis=NoneType, ), but expected one of: