DeepTrackAI / DeepTrack2


crop with sources #201

Closed. cmanzo closed this 7 months ago

cmanzo commented 8 months ago

I'm trying to figure out a way to deal with the UNet example and the ssTEM dataset using dt.sources. Segmented and raw images are in different folders and I'm trying to have a pipeline that does parallel operations on corresponding images of both folders.

I couldn't find a way to do that with dt.Dataset, so I've solved it using a PyTorch dataset, but I'm not sure it is the most efficient way:

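from torch.utils.data import Dataset


class SegmentationDataset(Dataset):
    """Pairs each raw-image source with the corresponding label source."""

    def __init__(self, imagePaths, labelPaths, pipeline):
        self.imagePaths = imagePaths
        self.labelPaths = labelPaths
        self.pipeline = pipeline

    def __len__(self):
        return len(self.imagePaths)

    def __getitem__(self, idx):
        # Use the pipeline stored on the instance rather than a global name.
        image, label = self.pipeline([self.imagePaths[idx], self.labelPaths[idx]])
        return (image, label)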

The paths are created with dt.sources so that the flips (dt.FlipLR, dt.FlipUD) can be controlled per source:

raw_paths = dt.sources.ImageFolder(root=raw_path)
label_paths = dt.sources.ImageFolder(root=label_path)

raw_sources = raw_paths.product(flip_ud=[True, False], flip_lr=[True, False])
label_sources = label_paths.product(flip_ud=[True, False], flip_lr=[True, False])
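As I understand it, product pairs every path with every combination of the listed values, so each image appears once per flip setting. Here is a minimal plain-Python sketch of that expansion, assuming Cartesian-product semantics. The helper expand_sources and the file names are hypothetical and not part of DeepTrack2:

```python
from itertools import product


def expand_sources(paths, **options):
    """Pair every path with every combination of the given option values."""
    names = list(options)
    combos = list(product(*(options[name] for name in names)))
    return [
        {"path": path, **dict(zip(names, combo))}
        for path in paths
        for combo in combos
    ]


sources = expand_sources(
    ["raw/000.png", "raw/001.png"],  # hypothetical file names
    flip_ud=[True, False],
    flip_lr=[True, False],
)
print(len(sources))  # 2 paths x 2 x 2 flip settings = 8
```

With two flip flags, the dataset therefore becomes four times as long as the folder listing.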

The pipeline is built from the two pipelines below:

im_pipeline = (
    dt.LoadImage(raw_sources.path)
    >> dt.NormalizeMinMax()
    >> dt.FlipLR(raw_sources.flip_lr)
    >> dt.FlipUD(raw_sources.flip_ud)
    >> dt.MoveAxis(2, 0)
    >> dt.pytorch.ToTensor(dtype=torch.float)
)

lab_pipeline = (
    dt.LoadImage(label_sources.path)
    >> dt.Lambda(select_labels, class_labels=[255, 191])
    >> dt.FlipLR(label_sources.flip_lr)
    >> dt.FlipUD(label_sources.flip_ud)
    >> dt.MoveAxis(2, 0)
    >> dt.pytorch.ToTensor(dtype=torch.float)
)
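The select_labels helper is not shown above. As far as I understand, dt.Lambda expects a function that receives the keyword arguments and returns the function that is actually applied to the image. Below is a hypothetical sketch of such a helper. It is not the original implementation: it assumes that the selected grey values are remapped to 1, 2, ... and everything else to 0, and plain nested lists stand in for image arrays:

```python
def select_labels(class_labels):
    """Return a function that remaps a label mask.

    Pixels equal to class_labels[i] become i + 1; all other pixels become 0.
    """
    lookup = {value: index + 1 for index, value in enumerate(class_labels)}

    def inner(mask):
        # Apply the lookup pixel by pixel, defaulting to background (0).
        return [[lookup.get(pixel, 0) for pixel in row] for row in mask]

    return inner


mask = [[255, 0, 191], [128, 255, 255]]  # toy label image
remapped = select_labels(class_labels=[255, 191])(mask)
print(remapped)  # [[1, 0, 2], [0, 1, 1]]
```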

Is the PyTorch dataset approach the only way to achieve this?

One important point: I haven't been able to include dt.Crop, either within or outside the pipeline, because I get AttributeError: 'numpy.ndarray' object has no attribute 'properties'. I thought it was because of dt.config.disable_image_wrapper(), but that doesn't seem to be the case.

BenjaminMidtvedt commented 8 months ago

This should be possible with:

im_pipeline = dt.LoadImage(raw_sources.path) >> dt.NormalizeMinMax()
lab_pipeline = dt.LoadImage(label_sources.path) >> dt.Lambda(select_labels, class_labels=[255, 191])

pipeline = (
    (im_pipeline & lab_pipeline)
    >> dt.FlipLR(raw_sources.flip_lr)
    >> dt.FlipUD(raw_sources.flip_ud)
    >> dt.MoveAxis(2, 0)
    >> dt.pytorch.ToTensor(dtype=torch.float)
)

BenjaminMidtvedt commented 8 months ago

Regarding dt.Crop: nice catch, I will fix it.

BenjaminMidtvedt commented 8 months ago

@cmanzo the issue with crops should be fixed. Can you confirm?

cmanzo commented 8 months ago

@BenjaminMidtvedt Crop works now, but when it is applied to a joined pipeline it crops the image and the label differently. Is there an easy fix?
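The behaviour I'm after is that the crop position is sampled once and reused for both members of the pair. Here is a minimal sketch of that idea in plain Python. It does not use the DeepTrack2 API: paired_random_crop is a hypothetical helper and nested lists stand in for image arrays:

```python
import random


def paired_random_crop(image, label, crop_h, crop_w, rng=random):
    """Cut the same randomly placed window out of image and label."""
    height, width = len(image), len(image[0])
    if (len(label), len(label[0])) != (height, width):
        raise ValueError("image and label must have the same shape")
    if crop_h > height or crop_w > width:
        raise ValueError("crop is larger than the image")

    # Sample the top-left corner once ...
    top = rng.randint(0, height - crop_h)
    left = rng.randint(0, width - crop_w)

    # ... and reuse it for both inputs so the pair stays aligned.
    def window(rows):
        return [row[left:left + crop_w] for row in rows[top:top + crop_h]]

    return window(image), window(label)


# Toy data: the label equals the image, so alignment is easy to check.
image = [[10 * r + c for c in range(6)] for r in range(5)]
label = [[10 * r + c for c in range(6)] for r in range(5)]
img_crop, lab_crop = paired_random_crop(image, label, 3, 4, rng=random.Random(0))
```

Sampling the offset separately for each input is what produces the mismatch described above.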

BenjaminMidtvedt commented 7 months ago

Should be fixed now.