notmatthancock / pylidc

An object relational mapping for the LIDC dataset using sqlalchemy.
https://pylidc.github.io
Other
104 stars 40 forks source link

Majorly slow dataloader using pylidc #68

Open joangog opened 4 months ago

joangog commented 4 months ago

Hi! I am trying to implement a dataloader for the LIDC dataset using pylidc. Unfortunately there is major slowdown (20 seconds per epoch to 15 minutes per epoch) in training my model when using this dataloader. I wonder if the issue is that the loading data/annotation function of a specific sample is a bit slow in pylidc. The augmentations are for sure not slow because I apply them on a different dataset (BraTS) and I have no problems. Does anyone know what the issue is, and what a solution would be to implement an efficient dataloader? Currectly this is what I have:

class LidcFineTune(Dataset):
    def __init__(self, config, img_list, crop_size=(128, 128, 64), train=False):  # 64 because some raw data don't have many slices
        self.config = config
        self.train = train
        self.img_list = img_list
        self.crop_size = crop_size

        # Create setup file for pylidc
        txt = f"""
        [dicom]
        path = {config.data}
        warn = True
        """
        with open(os.path.join(os.path.expanduser('~'),'.pylidcrc'), 'w') as file:
            file.write(txt)

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index):
        pid = self.img_list[index]
        scan = pl.query(pl.Scan).filter(pl.Scan.patient_id == pid).first()
        ann = pl.query(pl.Annotation).filter(pl.Scan.patient_id == pid).first()

        # Image
        try:
            x = torch.FloatTensor(scan.to_volume())
        except Exception as e:
            raise RuntimeError(f"Corrupted file in {pid}. Redownload!") from e

        # Segmentation mask
        y = torch.zeros(x.shape)
        mask = torch.FloatTensor(ann.boolean_mask())
        bbox = ann.bbox()
        y[bbox[0].start:bbox[0].stop,bbox[1].start:bbox[1].stop,bbox[2].start:bbox[2].stop] = mask

        # Resize
        x = x.T.unsqueeze(1) # Move slice dim to batch dim and add temporary channel dimension (H x W x D) -> (D x 1 x H x W)
        y = y.T.unsqueeze(1)
        x = f.interpolate(x, scale_factor=(0.5,0.5))  # Scale only height and weight, not slice dim
        y = f.interpolate(y, scale_factor=(0.5,0.5))
        x = x.permute(1,2,3,0)  # Put slice dim last (D x 1 x H x W -> 1 x H x W x D)
        y = y.permute(1,2,3,0)

        x, y = self.aug_sample(x, y)

        # min max
        x = self.normalize(x)

        return x, y

    def aug_sample(self, x, y):
        if self.train:
            # Random crop and augment
            x, y = self.random_crop(x, y)
            if random.random() < 0.5:
                x = torch.flip(x, dims=(1,))  # torch.flip not the source of the major slowdown
                y = torch.flip(y, dims=(1,))
            if random.random() < 0.5:
                x = torch.flip(x, dims=(2,))
                y = torch.flip(y, dims=(2,))
            if random.random() < 0.5:
                x = torch.flip(x, dims=(3,))
                y = torch.flip(y, dims=(3,))
        else:
            # Center crop
            x, y = self.center_crop(x, y)

        return x, y

    def random_crop(self, x, y):
        """
        Args:
            x: 4d array, [channel, h, w, d]
        """
        crop_size = self.crop_size
        height, width, depth = x.shape[-3:]
        sx = random.randint(0, height - crop_size[0] - 1)
        sy = random.randint(0, width - crop_size[1] - 1)
        sz = random.randint(0, depth - crop_size[2] - 1)
        crop_volume = x[:, sx:sx + crop_size[0], sy:sy + crop_size[1], sz:sz + crop_size[2]]
        crop_seg = y[:, sx:sx + crop_size[0], sy:sy + crop_size[1], sz:sz + crop_size[2]]

        return crop_volume, crop_seg

    def center_crop(self, x, y):
        crop_size = self.crop_size
        height, width, depth = x.shape[-3:]
        sx = (height - crop_size[0] - 1) // 2
        sy = (width - crop_size[1] - 1) // 2
        sz = (depth - crop_size[2] - 1) // 2
        crop_volume = x[:, sx:sx + crop_size[0], sy:sy + crop_size[1], sz:sz + crop_size[2]]
        crop_seg = y[:, sx:sx + crop_size[0], sy:sy + crop_size[1], sz:sz + crop_size[2]]

        return crop_volume, crop_seg

    def normalize(self, x):
        return (x - x.min()) / (x.max() - x.min())
notmatthancock commented 4 months ago

You're repeatedly computing scan.to_volume() and ann.boolean_mask() in your augmentations pipeline. to_volume() reads all the DICOM from disk and converts it to a numpy volume. ann.boolean_mask() does ray-casting for each pixel to determine if it is inside or outside of the annotation contour. Both these things are pretty expensive. If you are generating multiple augmentations from the same index in __get_item__ I would recommend caching the base data used to generate those augmentations.

Emvlt commented 1 month ago

Hello! We are working on making a pytorch compatible dataloader for LIDC-IDRI: please check out https://github.com/CambridgeCIA/LION/tree/main/LION/data_loaders for direct use or some inspiration to write your own :) Currently, we use the dataset in 2D and preprocess it for this purpose, and to avoid the expensive calls mentioned by @notmatthancock