FZJ-INM1-BDA / celldetection

Scalable Instance Segmentation using PyTorch & PyTorch Lightning.
https://docs.celldetection.org
Apache License 2.0

Issue setting up data for training #4

Closed: LeoMira-1999 closed this issue 2 years ago

LeoMira-1999 commented 2 years ago

Hi @ericup!

First of all, I would like to compliment you on your repo, nicely done!

It's been a week since I found your repo, and I've been trying to load my cell library images. I've been using your "Cell Detection with Contour Proposal Networks" notebook as a reference, plus the two other notebooks in the demos folder.

The images I have are from dsb2018. I have two folders, train and test, with 256x256 images and no metadata folder. Inside each of them are two more folders, images and masks, containing the TIFF images and masks respectively. How would you proceed?

If you have any idea or approach, I'll be glad to hear from you.

Whilst I wait for your response,

Leonardo

ericup commented 2 years ago

Hi Leonardo,

Thank you very much!

The data structure sounds a little different from what you would get from the Broad Institute's website. But you probably have something like this, right?

- train/
-- images/
--- img0.tiff
--- img1.tiff
    ...
-- masks/
--- img0_mask0.tiff
--- img0_mask1.tiff
--- img1_mask0.tiff
--- img1_mask1.tiff
--- img1_mask2.tiff
    ...
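Pairing images with masks by file name needs some care: a wildcard such as `img1*` also matches `img10_mask0.tiff`. A minimal standard-library sketch, assuming masks are named `<image stem>_<anything>` as in the tree above and that files may end in either `.tif` or `.tiff`; `pair_images_with_masks` is a hypothetical helper, not part of celldetection:

```python
from pathlib import Path


def pair_images_with_masks(root, extensions=('.tif', '.tiff')):
    """Map each image in root/images to its sorted mask paths in root/masks.

    Hypothetical helper: assumes masks are named '<image stem>_<suffix><ext>'.
    """
    root = Path(root)
    images = sorted(p for p in (root / 'images').iterdir() if p.suffix.lower() in extensions)
    masks = sorted(p for p in (root / 'masks').iterdir() if p.suffix.lower() in extensions)
    pairs = {}
    for img in images:
        prefix = img.stem + '_'  # the underscore keeps 'img1' from matching 'img10_...'
        pairs[img] = [m for m in masks if m.name.startswith(prefix)]
    return pairs
```

Matching on the stem plus an underscore, rather than a bare wildcard, is what keeps `img1` and `img10` apart.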

If you start with the Cell Detection with Contour Proposal Networks demo notebook, you should replace these lines:

train_bbbc039 = cd.data.BBBC039Train(conf.directory, download=conf.download_data)
val_bbbc039 = cd.data.BBBC039Val(conf.directory)
test_bbbc039 = cd.data.BBBC039Test(conf.directory)

with some code that loads your data.

Given the structure you described it could look something like this:

import celldetection as cd
from glob import glob
from imageio import imread
from os.path import basename, dirname, join
import numpy as np

class DataFolder:
    def __init__(self, root):
        self.images, self.masks, self.labels = [], [], []
        self.names = sorted(glob(join(root, 'images', '*.tiff')))  # find image names (sorted for a stable order)
        for img_name in self.names:
            # find the masks for this image; the '_' keeps img1 from matching img10_mask0.tiff
            mask_names = sorted(glob(join(root, 'masks', basename(img_name).replace('.tiff', '_*.tiff'))))
            self.images.append(imread(img_name))  # read image
            masks = np.stack([imread(mask_name) for mask_name in mask_names], 0)  # read masks
            labels = cd.data.unary_masks2labels(masks).max(-1, keepdims=True)  # convert masks to label image
            self.masks.append(masks)  # append to lists
            self.labels.append(labels)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, item):
        return self.names[item], self.images[item], self.masks[item], self.labels[item]

train_dsb2018 = DataFolder('./train')
test_dsb2018 = DataFolder('./test')
# name, img, masks, labels = train_dsb2018[0]

This example reads all your data from the train or test folder and converts the masks for each image to label images. Note that this code assumes that each mask contains exactly one object, which should be the case for dsb2018. If you have just one mask with many objects per image you can have a look at cd.data.masks2labels to replace cd.data.unary_masks2labels. If you are using RGB images you should also change the number of channels in the Config from in_channels=1 to in_channels=3. Apart from that you might want to try different data augmentation and normalization.
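Conceptually, turning one binary mask per object into a single label image means writing object index i + 1 wherever mask i is set. A minimal NumPy sketch of that idea, assuming one object per mask; `unary_masks_to_label_image` is a hypothetical name, and this is not how `cd.data.unary_masks2labels` is implemented:

```python
import numpy as np


def unary_masks_to_label_image(masks):
    """Combine an (n, h, w) stack of binary masks into an (h, w, 1) int32 label image.

    Background is 0 and object i (0-based) gets label i + 1.
    """
    masks = np.asarray(masks) > 0  # treat any nonzero value (e.g. 255) as foreground
    labels = np.zeros(masks.shape[1:], dtype=np.int32)
    for i, mask in enumerate(masks):
        labels[mask] = i + 1  # where masks overlap, the higher index wins
    return labels[..., None]  # trailing channel axis, like keepdims=True
```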

I hope this helps and gets you started!

Best regards, Eric

LeoMira-1999 commented 2 years ago

Thanks for your amazingly fast response!

It worked with your help! My files are .tif, so that was the only thing I changed.

Now I'm having trouble visualizing the data with this function:

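import celldetection as cd
import matplotlib.pyplot as plt

# `transforms` is assumed to be the augmentation pipeline defined earlier in the notebook
def demo_transforms(img, lbl, name=None):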
    img = cd.data.normalize_percentile(img, percentile=99.8)
    lbl = lbl.copy()

    # Show original
    cd.vis.show_detection(img, contours=cd.data.labels2contours(lbl), contour_linestyle='-')
    if name is not None:
        plt.title(name)
    plt.show()

    # Show transformed
    s = 3
    plt.figure(None, (s * 9, s * 9))
    for i in range(1, s * s + 1):
        plt.subplot(s, s, i)
        trans = transforms(image=img, mask=lbl.astype('int32'))
        t_img, t_lbl = trans['image'], trans['mask']
        cd.data.relabel_(t_lbl)
        plt.title(f'transformed {t_img.dtype.name, t_img.shape, t_lbl.dtype.name, t_lbl.shape}')
        cd.vis.show_detection(t_img, contours=cd.data.labels2contours(t_lbl), contour_linestyle='-')
    plt.show()
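The `cd.data.relabel_` call matters because augmentations such as cropping can remove objects, leaving gaps in the label IDs. A NumPy sketch of the idea, mapping whatever nonzero IDs remain onto 1..N; `relabel_consecutive` is a hypothetical name and not celldetection's implementation (which works in place):

```python
import numpy as np


def relabel_consecutive(labels):
    """Return a copy of `labels` where the N distinct nonzero IDs become 1..N (0 stays background)."""
    ids, inverse = np.unique(labels, return_inverse=True)
    # np.unique sorts ids ascending, so if background 0 is present it sits at index 0
    offset = 0 if ids[0] == 0 else 1
    return (inverse + offset).reshape(labels.shape).astype(labels.dtype)
```

For example, a crop that keeps objects 2, 5 and 9 would turn them into 1, 2 and 3.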

And I guess this class from your notebook shouldn't be changed?

from collections import OrderedDict

import celldetection as cd
import numpy as np


class Data:
    def __init__(self, data, config, transforms=None, items=None, size=None):
        self.transforms = transforms
        self.gen = cd.data.CPNTargetGenerator(
            samples=config.samples,
            order=config.order,
            max_bg_dist=config.bg_fg_dists[0],
            min_fg_dist=config.bg_fg_dists[1],
        )
        self._items = items or len(data)
        self.data = data
        self.size = size
        self.channels = config.in_channels

    def __len__(self):
        return self._items

    @staticmethod
    def map(image):
        image = image / 127.5
        image -= 1
        if image.ndim == 2:
            image = image[..., None]
        return image.astype('float32')

    @staticmethod
    def unmap(image):
        image = (image + 1) * 127.5
        image = np.clip(image, 0, 255).astype('uint8')
        if image.ndim == 3 and image.shape[2] == 1:
            image = np.squeeze(image, 2)
        return image

    def __getitem__(self, item):
        if item >= len(self):
            raise IndexError('Index out of bounds.')
        item = item % len(self.data)

        # Get image and labels
        name, img, _, labels = self.data[item]

        # Normalize intensities
        img, labels = np.copy(img).squeeze(), np.copy(labels)
        img = cd.data.normalize_percentile(img, percentile=99.8)
        labels = labels.astype('int32')       

        # Optionally crop
        if self.size is not None:
            h, w = self.size
            img, labels = cd.data.random_crop(img, labels, height=h, width=w)

        # Optionally transform
        if self.transforms is not None:
            r = self.transforms(image=img, mask=labels)
            img, labels = r['image'], r['mask']

        # Ensure channels exist
        if labels.ndim == 2:
            labels = labels[..., None]

        # Relabel to ensure that N objects are marked with integers 1..N
        cd.data.relabel_(labels)

        # Feed labels to target generator
        gen = self.gen
        gen.feed(labels=labels)

        # Map image to range -1..1
        image = self.map(img)

        # Return as dictionary
        return OrderedDict({
            'inputs': image,
            'labels': gen.reduced_labels,
            'fourier': (gen.fourier.astype('float32'),),
            'locations': (gen.locations.astype('float32'),),
            'sampled_contours': (gen.sampled_contours.astype('float32'),),
            'sampling': (gen.sampling.astype('float32'),),
            'targets': labels
        })

However, order_plot and plot_example work like a charm for visualizing the data, so maybe I don't need the previous visualization after all:

import celldetection as cd
import matplotlib.pyplot as plt
import numpy as np


def plot_example(data_loader, figsize=(8, 4.5)):
    # Pick example
    example = cd.asnumpy(next(iter(data_loader)))
    contours = example['sampled_contours'][0]
    image = cd.data.channels_first2channels_last(example['inputs'][0])
    # Plot example
    cd.vis.show_detection(image, contours=contours, figsize=figsize, contour_linestyle='-',
                          cmap='gray' if image.shape[2] == 1 else ...)
    plt.colorbar()
    plt.tight_layout()
    plt.ylim([0, image.shape[0]])
    plt.xlim([0, image.shape[1]])


def order_plot(data, data_loader):
    # Plot example data for different `order` settings (uses the notebook's global config `conf`)
    gen = data.gen
    s = int(np.ceil(np.sqrt(conf.order)))
    plt.figure(None, (s * 12, s * 6.75))
    for gen.order in range(1, conf.order + 1):
        plt.subplot(s, s, gen.order)
        plot_example(data_loader, figsize=None)
        plt.title(f'{"Chosen o" if gen.order == conf.order else "O"}rder: {gen.order}')
    plt.show()
    assert gen.order == conf.order

Setting the visualization aside, I went ahead to check whether the model would train. It is training now, and I'll see if the results are coherent.

EDIT: One more question: once you have finished training your model, is it saved somewhere?

Thanks again

Leonardo

ericup commented 2 years ago

What kind of trouble are you having with the demo_transforms function? I don't think you need to change the Data class, but you could, for example, change the map and unmap functions to normalize differently. Also, you might not need cd.data.normalize_percentile for this dataset.
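As one example of normalizing differently, `map` could rescale each image by its own minimum and maximum instead of assuming a 0..255 range, which would also make a separate percentile step optional. A sketch assuming this per-image min-max scheme suits the data (it is one option among several, not what the demo does); `map_minmax` is a hypothetical replacement for `Data.map`, and `unmap` can stay as it is for display since it only maps -1..1 back to 0..255:

```python
import numpy as np


def map_minmax(image):
    """Scale an image to -1..1 using its own min and max; add a channel axis if missing."""
    image = image.astype('float32')
    lo, hi = image.min(), image.max()
    image = (image - lo) / max(hi - lo, 1e-8)  # 0..1; the guard avoids dividing by zero on constant images
    image = image * 2 - 1  # -1..1
    if image.ndim == 2:
        image = image[..., None]
    return image.astype('float32', copy=False)
```

Unlike the fixed `/ 127.5` mapping, this works for 16-bit or float images too, at the cost of per-image contrast differences being normalized away.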

The demo currently does not save your model, but you can of course save it as usual.
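Saving "as usual" in PyTorch typically means storing the model's `state_dict` with `torch.save` and loading it back into a freshly constructed model of the same architecture. A sketch with a small stand-in `nn.Sequential`, since the exact arguments for rebuilding the demo's CPN model are not shown in this thread; `save_weights`, `load_weights` and the file name are illustrative:

```python
import os
import tempfile

import torch
from torch import nn


def save_weights(model, path):
    """Store only the parameters and buffers, not the Python object."""
    torch.save(model.state_dict(), path)


def load_weights(model, path):
    """Load weights saved by save_weights into a model with the same architecture."""
    state = torch.load(path, map_location='cpu')
    model.load_state_dict(state)
    return model.eval()  # inference mode (affects e.g. dropout and batch norm)


# Stand-in model; in the demo this would be the trained CPN model instead
model = nn.Sequential(nn.Conv2d(1, 4, 3), nn.ReLU())
path = os.path.join(tempfile.mkdtemp(), 'cpn_weights.pt')
save_weights(model, path)
restored = load_weights(nn.Sequential(nn.Conv2d(1, 4, 3), nn.ReLU()), path)
```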

LeoMira-1999 commented 2 years ago

The transformation works fine; I just couldn't display the transformed images, but I won't focus on that anymore.

Indeed, I saved the model the regular torch way. Is there a way to change the default Fourier descriptor setting in the config file, or anywhere else, to test it over range(6, 18, 4)? My guess would be the 'order' parameter, but maybe there's more to it?

Thanks a lot for all the help you provided!

ericup commented 2 years ago

Changing the order setting in the Config should be sufficient. At test time, you can also set model.order of a trained model to a smaller value. The model then simply omits part of the descriptor and outputs results according to the given order.
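To illustrate what omitting part of the descriptor means: a closed contour can be written as a Fourier series, and evaluating only the first `order` harmonics gives a smoother approximation of the same shape. A NumPy sketch using the standard elliptic Fourier form x(t) = a0 + sum over k of a_k cos(2πkt) + b_k sin(2πkt) (and likewise for y); the function names and the coefficient layout are assumptions for illustration, not necessarily what CPN uses internally:

```python
import numpy as np


def fourier_coefficients(contour, order):
    """Return the offset (2,) and (order, 2, 2) cosine/sine coefficients of a closed contour.

    contour: (n, 2) array of x, y points sampled at equally spaced t in [0, 1).
    coeffs[k - 1] = [[a_k, b_k], [c_k, d_k]] for harmonic k.
    """
    n = len(contour)
    spectrum = np.fft.rfft(contour, axis=0) / n  # (n // 2 + 1, 2) complex
    offset = spectrum[0].real
    harmonics = spectrum[1:order + 1]
    cos_part = 2 * harmonics.real   # a_k (x) and c_k (y)
    sin_part = -2 * harmonics.imag  # b_k (x) and d_k (y)
    return offset, np.stack([cos_part, sin_part], axis=-1)


def reconstruct(offset, coeffs, order, samples=64):
    """Evaluate the contour using only the first `order` harmonics (order <= len(coeffs))."""
    t = np.arange(samples) / samples
    k = np.arange(1, order + 1)[:, None]  # (order, 1)
    angles = 2 * np.pi * k * t  # (order, samples)
    c = coeffs[:order]  # dropping higher harmonics corresponds to a smaller order
    x = offset[0] + (c[:, 0, 0, None] * np.cos(angles) + c[:, 0, 1, None] * np.sin(angles)).sum(0)
    y = offset[1] + (c[:, 1, 0, None] * np.cos(angles) + c[:, 1, 1, None] * np.sin(angles)).sum(0)
    return np.stack([x, y], axis=1)
```

With `order=1` only the fundamental ellipse remains; higher orders add finer shape detail.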

If you want to try different configs you might find this helpful.
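For sweeping the order as in `range(6, 18, 4)` (that is, 6, 10 and 14), one option is to build one config per setting and train a model for each. A standard-library sketch in which plain dicts stand in for the demo's Config; `config_grid` and the `base_config` values are hypothetical:

```python
from itertools import product


def config_grid(base, **options):
    """Yield one dict per combination of the given option values, on top of `base`.

    Plain-dict stand-in for the demo's Config; the keys are illustrative.
    """
    keys = list(options)
    for values in product(*(options[k] for k in keys)):
        config = dict(base)
        config.update(zip(keys, values))
        yield config


base_config = {'in_channels': 1, 'order': 6}  # hypothetical values
runs = list(config_grid(base_config, order=range(6, 18, 4)))
for config in runs:
    print(config['order'])  # train and evaluate one model per config here
```

Passing more keyword arguments (for example a second setting) produces the full cross product of runs.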

LeoMira-1999 commented 2 years ago

Thanks again for all the tips, and great work!