fepegar / torchio

Medical imaging toolkit for deep learning
https://torchio.org
Apache License 2.0

Silenced exception makes it harder to debug custom Transforms #1098

Closed Zhack47 closed 1 year ago

Zhack47 commented 1 year ago

Bug summary

In the _get_next_subject() method of the Queue class, there is a try / except statement that looks like this:

        try:
            subject = next(self.subjects_iterable)
        except StopIteration as exception:
            self._print('Queue is empty:', exception)
            self._initialize_subjects_iterable()
            subject = next(self.subjects_iterable)
        except AssertionError as exception:
            if 'can only test a child process' in str(exception):
                message = (
                    'The number of workers for the data loader used to pop'
                    ' patches from the queue should be 0. Is it?'
                )
                raise RuntimeError(message) from exception
        return subject

When an AssertionError is raised and the if condition is not met, the exception is swallowed and we get an UnboundLocalError saying that subject is not defined. The actual exception is lost, which makes debugging harder. In my case, the original error was AssertionError: Output of SimulateLowResolutionTransform is 5D, which describes my problem much better.
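
The masking can be reproduced without TorchIO. What follows is only a minimal sketch: the names failing_iterable and get_next_subject are hypothetical, and the function mirrors just the structure of the try / except quoted above, not the actual library code.

```python
def failing_iterable():
    """Hypothetical stand-in for the subjects iterable; fails on the first item."""
    raise AssertionError('Output of SimulateLowResolutionTransform is 5D')
    yield  # unreachable, but makes this function a generator


def get_next_subject(subjects_iterable):
    """Mirror the structure of the quoted try / except (simplified)."""
    try:
        subject = next(subjects_iterable)
    except AssertionError as exception:
        if 'can only test a child process' in str(exception):
            raise RuntimeError('num_workers should be 0') from exception
        # Nothing is re-raised here, so the AssertionError is swallowed
        # and execution continues with 'subject' never assigned.
    return subject


try:
    get_next_subject(failing_iterable())
except UnboundLocalError as error:
    caught = error
    # Only the UnboundLocalError surfaces; the AssertionError text is gone.
    print(type(error).__name__)
```

Running this prints UnboundLocalError, and the message about the 5D output never appears.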

To remove the confusion, we could re-raise the original AssertionError when it does not match the if condition.

Code for reproduction

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchio.data.sampler.label import LabelSampler
from batchgenerators.augmentations.resample_augmentations import augment_linear_downsampling_scipy
from torchio import DATA, TYPE, LABEL, INTENSITY, IntensityTransform

class SimulateLowResolutionTransform(IntensityTransform):

    def __init__(self, zoom_range=(0.5, 1), per_channel=False, p_per_channel=1,
                 channels=None, order_downsample=1, order_upsample=0, data_key="data", p_per_sample=1,
                 ignore_axes=None):
        super().__init__(p_per_sample)
        self.order_upsample = order_upsample
        self.order_downsample = order_downsample
        self.channels = channels
        self.per_channel = per_channel
        self.p_per_channel = p_per_channel
        self.p_per_sample = p_per_sample
        self.data_key = data_key
        self.zoom_range = zoom_range
        self.ignore_axes = ignore_axes

    def apply_transform(self, subject):
        keys = sorted(subject.keys())
        for key in keys:
            if subject[key][TYPE] == INTENSITY:
                data = subject[key][DATA].unsqueeze(1).numpy()
                for i in range(data.shape[0]):
                    if np.random.uniform() < self.p_per_sample:
                        data[i] = augment_linear_downsampling_scipy(data[i], zoom_range=self.zoom_range,
                                                                    per_channel=self.per_channel,
                                                                    p_per_channel=self.p_per_channel,
                                                                    channels=self.channels,
                                                                    order_downsample=self.order_downsample,
                                                                    order_upsample=self.order_upsample,
                                                                    ignore_axes=self.ignore_axes)
                subject[key][DATA] = torch.tensor(data)
        return subject

if __name__ == "__main__":
    import torchio as tio
    st = SimulateLowResolutionTransform(zoom_range=(.05, 1.), per_channel=True,
                                                     p_per_channel=1.,
                                                     order_downsample=0, order_upsample=0, p_per_sample=1.,
                                                     ignore_axes=None)
    colin_dataset = tio.datasets.mni.Colin27()

    ds_train = tio.SubjectsDataset([colin_dataset], transform=st)
    sampler = LabelSampler((120, 120, 80))
    patches_queue_train = tio.Queue(ds_train, max_length=32, samples_per_volume=4, sampler=sampler,
                                    shuffle_patches=True, shuffle_subjects=True, num_workers=8)

    training_loader = DataLoader(patches_queue_train, batch_size=2, shuffle=True)
    for batch in training_loader:
        print(batch["t1"][DATA].shape)

Actual outcome

Traceback (most recent call last):
  File "/home/zhack/Documents/THESE/4Net/fournet/utils/transforms/augmentations/spatial_augments.py", line 304, in <module>
    for batch in training_loader:
  File "/home/zhack/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
    data = self._next_data()
  File "/home/zhack/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 671, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/zhack/.local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/zhack/.local/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/zhack/.local/lib/python3.10/site-packages/torchio/data/queue.py", line 170, in __getitem__
    self._fill()
  File "/home/zhack/.local/lib/python3.10/site-packages/torchio/data/queue.py", line 229, in _fill
    subject = self._get_next_subject()
  File "/home/zhack/.local/lib/python3.10/site-packages/torchio/data/queue.py", line 270, in _get_next_subject
    return subject
UnboundLocalError: local variable 'subject' referenced before assignment

Error messages

UnboundLocalError: local variable 'subject' referenced before assignment

Expected outcome

AssertionError: Output of SimulateLowResolutionTransform is 5D

System info

pegar/torchio/main/print_system.py)
Platform:   Linux-5.15.0-79-generic-x86_64-with-glibc2.35
TorchIO:    0.18.78
PyTorch:    1.13.1+cu117
SimpleITK:  2.1.1.2 (ITK 5.2)
NumPy:      1.22.0
Python:     3.10.6 (main, May 29 2023, 11:10:38) [GCC 11.3.0]

fepegar commented 1 year ago

Hi, @Zhack47. Can you please share a minimal example I can reproduce?

Zhack47 commented 1 year ago

Here is a more minimal example that should trigger the bug (tested on the same machine as above):

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchio.data.sampler.label import LabelSampler
from torchio import DATA, TYPE, LABEL, INTENSITY, IntensityTransform

class SimulateLowResolutionTransform(IntensityTransform):
    def __init__(self):
        super().__init__(1)

    def apply_transform(self, subject):
        keys = sorted(subject.keys())
        for key in keys:
            subject[key][DATA] = subject[key][DATA].unsqueeze(0)
        return subject

if __name__ == "__main__":
    import torchio as tio
    st = SimulateLowResolutionTransform()
    colin_dataset = tio.datasets.mni.Colin27()
    ds_train = tio.SubjectsDataset([colin_dataset], transform=st)
    sampler = LabelSampler((120, 120, 80))
    patches_queue_train = tio.Queue(ds_train, max_length=32, samples_per_volume=4, sampler=sampler,
                                    shuffle_patches=True, shuffle_subjects=True, num_workers=8)

    training_loader = DataLoader(patches_queue_train, batch_size=2, shuffle=True)
    for batch in training_loader:
        print(batch["t1"][DATA].shape)

fepegar commented 1 year ago

Thanks, @Zhack47. Good catch! I think adding raise exception after line 330 would do. Do you agree? Would you like to contribute a PR?
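
A rough sketch of what that change might look like, assuming the added raise exception sits at the end of the except AssertionError block, outside the if. The function name get_next_subject_patched and the failing_iterable helper are hypothetical, and the StopIteration branch is omitted for brevity; this is not the actual patch.

```python
def failing_iterable():
    """Hypothetical stand-in for the subjects iterable; fails on the first item."""
    raise AssertionError('Output of SimulateLowResolutionTransform is 5D')
    yield  # unreachable, but makes this function a generator


def get_next_subject_patched(subjects_iterable):
    """Simplified version of the method with the suggested re-raise added."""
    try:
        subject = next(subjects_iterable)
    except AssertionError as exception:
        if 'can only test a child process' in str(exception):
            message = (
                'The number of workers for the data loader used to pop'
                ' patches from the queue should be 0. Is it?'
            )
            raise RuntimeError(message) from exception
        # Suggested addition: any other AssertionError is propagated
        # unchanged instead of being silently swallowed.
        raise exception
    return subject


try:
    get_next_subject_patched(failing_iterable())
except AssertionError as error:
    original_message = str(error)
    print(original_message)
```

With the re-raise in place, the caller sees the original message (Output of SimulateLowResolutionTransform is 5D) rather than an UnboundLocalError, and the normal path still returns the subject.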

Zhack47 commented 1 year ago

Thanks! I completely agree with this solution. I am going to make a PR to fix this!

fepegar commented 1 year ago

Fixed in v0.19.1.