michaellee1 / CellSeg

MIT License
21 stars 25 forks source link

Mask growth breaks for edge cases #66

Open MeyerBender opened 7 months ago

MeyerBender commented 7 months ago

Hi,

While investigating the mask-growing method, I came across some unexpected behavior that looks incorrect to me. For example, notice how the grown mask on the left side of the image occupies pixels that belonged to other cells in the original segmentation.

Original image: image

Image grown by 1px: image

I have extracted the corresponding code snippets from the CVMask class to create this standalone example for testing:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from skimage.morphology import disk, dilation
from scipy.ndimage.morphology import binary_dilation
from sklearn.neighbors import kneighbors_graph
from scipy.spatial.distance import cdist

# adapted from CVMask
def compute_centroids(flatmasks):
    """Return the mean pixel position of every labelled mask.

    Pixels equal to 0 are ignored. The result is a list of ``(x, y)`` tuples
    ordered by ascending label, where ``x`` is the mean row index and ``y`` the
    mean column index of the pixels carrying that label.
    """
    row_idx, col_idx = np.nonzero(flatmasks != 0)
    # One row per labelled pixel: (row, column, label).
    pixel_table = pd.DataFrame(
        np.column_stack((row_idx, col_idx, flatmasks[row_idx, col_idx])),
        columns=["x", "y", "id"],
    )
    per_label = pixel_table.groupby("id").agg(x=("x", "mean"), y=("y", "mean"))
    return [tuple(pair) for pair in per_label.to_numpy().tolist()]

# adapted from CVMask
def remove_overlaps_nearest_neighbors(centroids, masks):
        """Flatten a stack of label layers, giving contested pixels to the nearest centroid.

        Parameters
        ----------
        centroids : list of (x, y)
            Looked up as ``centroids[mask_id - 1]``; presumably the output of
            ``compute_centroids`` with labels exactly 1..N -- TODO confirm.
        masks : ndarray, shape (H, W, L)
            Stack of label layers (0 = empty), e.g. the dilated layers from
            ``grow_masks``.

        Returns
        -------
        ndarray, shape (H, W)
            ``np.max`` over the layers, except that every pixel claimed by more
            than one layer is set to the claiming label whose centroid is closest
            to the pixel (ties go to the first candidate via ``np.argmin``).

        NOTE(review): only the grown layers are consulted. Which label a pixel
        had before growth is never considered, so a pixel of one cell can be
        reassigned to a neighbouring cell whose centroid happens to be closer.
        """
        final_masks = np.max(masks, axis = 2)
        # (rows, cols) of pixels where at least two layers are non-zero
        collisions = np.nonzero(np.sum(masks > 0, axis = 2) > 1)
        # shape (K, L): the layer values at each of the K colliding pixels
        collision_masks = masks[collisions]
        # (collision number, layer) pairs for the non-zero entries only
        collision_index = np.nonzero(collision_masks)
        collision_masks = collision_masks[collision_index]
        collision_frame = pd.DataFrame(np.transpose(np.array([collision_index[0], collision_masks]))).rename(columns = {0:"collis_idx", 1:"mask_id"})
        grouped_frame = collision_frame.groupby('collis_idx')
        # one Python iteration per colliding pixel
        for collis_idx, group in grouped_frame:
            collis_pos = np.expand_dims(np.array([collisions[0][collis_idx], collisions[1][collis_idx]]), axis = 0)
            # NOTE(review): assigned but never used
            prevval = final_masks[collis_pos[0,0], collis_pos[0,1]]
            # candidate labels, in increasing layer order
            mask_ids = list(group['mask_id'])
            curr_centroids = np.array([centroids[mask_id - 1] for mask_id in mask_ids])
            dists = cdist(curr_centroids, collis_pos)
            closest_mask = mask_ids[np.argmin(dists)]
            final_masks[collis_pos[0,0], collis_pos[0,1]] = closest_mask

        return final_masks

# adapted from CVMask
def grow_masks(flatmasks, centroids, growth, method = 'Standard', num_neighbors = 30):
    """Grow every labelled mask by ``growth`` pixels (reproduction of the 'Standard' method).

    Masks are split into layers so that a mask does not share a layer with
    already-labelled centroid k-nearest neighbours. Each layer is dilated
    ``growth`` times with a ``disk(1)`` footprint, and the layers are flattened
    once at the end by ``remove_overlaps_nearest_neighbors``.

    Parameters
    ----------
    flatmasks : ndarray, shape (H, W)
        Label image; 0 is background. The id arithmetic below presumably
        requires labels to be exactly 1..N -- TODO confirm.
    centroids : list of (x, y)
        Forwarded to ``remove_overlaps_nearest_neighbors``.
    growth : int
        Number of one-pixel dilations applied to each layer.
    method : str
        Only ``'Standard'`` is handled; any other value falls through and the
        function returns ``None``.
    num_neighbors : int
        ``n_neighbors`` passed to ``kneighbors_graph``.

    Returns
    -------
    ndarray, shape (H, W)
        Flattened grown labels (when ``method == 'Standard'``).
    """
    masks = flatmasks
    # subtracts one for background label 0; off by one if the image has no background
    num_masks = len(np.unique(masks)) - 1

    # only looking at the standard method, but sequential also appears to have some issues
    if method == 'Standard':
        print("Standard growth selected")
        masks = flatmasks
        num_masks = len(np.unique(masks)) - 1
        indices = np.where(masks != 0)
        values = masks[indices[0], indices[1]]

        maskframe = pd.DataFrame(np.transpose(np.array([indices[0], indices[1], values]))).rename(columns = {0:"x", 1:"y", 2:"id"})
        # one (mean row, mean column) row per label, in ascending label order
        cent_array = maskframe.groupby('id').agg({'x': 'mean', 'y': 'mean'}).to_numpy()
        # column j is multiplied by j + 1, turning each 1 of the neighbour matrix into
        # that neighbour's label (assumes labels are 1..N in sorted order)
        connectivity_matrix = kneighbors_graph(cent_array, num_neighbors).toarray() * np.arange(1, num_masks + 1)
        connectivity_matrix = connectivity_matrix.astype(int)
        # labels: mask id -> layer index, assigned greedily in label order
        labels = {}
        for n in range(num_masks):
            connections = list(connectivity_matrix[n, :])
            # drops one 0 entry; remaining zeros are skipped below since label keys start at 1
            connections.remove(0)
            # only neighbours labelled earlier in this loop are considered; the
            # k-nearest-neighbour relation is not necessarily symmetric
            layers_used = [labels[i] for i in connections if i in labels]
            layers_used.sort()
            currlayer = 0
            # NOTE(review): layers_used may contain duplicates (two neighbours in one
            # layer). E.g. [0, 0, 1] stops at the second 0 and yields layer 1, which a
            # neighbour already occupies.
            for layer in layers_used:
                if currlayer != layer: 
                    break
                currlayer += 1
            labels[n + 1] = currlayer

        possible_layers = len(list(set(labels.values())))
        label_frame = pd.DataFrame(list(labels.items()), columns = ["maskid", "layer"])
        image_h, image_w = masks.shape
        # one label plane per layer
        expanded_masks = np.zeros((image_h, image_w, possible_layers), dtype = np.uint32)

        grouped_frame = label_frame.groupby('layer')
        for layer, group in grouped_frame:
            currids = list(group['maskid'])
            masklocs = np.isin(masks, currids)
            expanded_masks[masklocs, layer] = masks[masklocs]

        dilation_mask = disk(1)
        grown_masks = np.copy(expanded_masks)
        # all growth iterations run first; overlaps are resolved only once afterwards
        for _ in range(growth):
            for i in range(possible_layers):
                grown_masks[:, :, i] = dilation(grown_masks[:, :, i], dilation_mask)
        return remove_overlaps_nearest_neighbors(centroids, grown_masks)

# 20x20 toy label image: 0 is background, 1-9 are cell ids.
example_data = np.array([[6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 0, 0, 0, 0, 1, 1, 1, 1],
       [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 0, 0, 0, 1, 1, 1, 1, 1],
       [6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1, 1],
       [6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1],
       [2, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 1, 1, 1, 1, 1],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 1, 1, 1, 1, 4],
       [2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 5, 5, 9, 9, 4, 4],
       [2, 2, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 5, 5, 5, 9, 9, 9, 9, 9],
       [2, 0, 0, 7, 7, 7, 7, 7, 7, 7, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9],
       [0, 0, 0, 0, 7, 7, 7, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9],
       [3, 0, 0, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9],
       [3, 3, 0, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9],
       [3, 3, 3, 0, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9],
       [3, 3, 8, 8, 8, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9]])

# Centroids are computed once, from the labels before growth.
centroids = compute_centroids(example_data)
# Grow every mask by one pixel with the 'Standard' layered method.
masks_grown = grow_masks(example_data, centroids, 1, method = 'Standard', num_neighbors = 8)
# Show the original labels, then the grown labels, for visual comparison.
plt.imshow(example_data)
plt.show()
plt.imshow(masks_grown)
plt.show()

I would greatly appreciate it if you could tell me whether I am using this method incorrectly, or whether this is actually a bug in the method. Thank you very much in advance!

MeyerBender commented 7 months ago

I spotted two issues with the current approach:

  1. The remove_overlaps method looked only at the grown masks, not at the original ones. This led to unexpected side effects and sometimes even removed cells entirely (when a cell was completely engulfed by a grown mask with a higher label).
  2. The remove_overlaps method should be called after each growth iteration, so that grown pixels cannot become dissociated from their parent masks.

Applying these changes to the example above gives more sensible results: [image]

Here is the code I used:

def compute_centroids(flatmasks):
    """Return the centroid of every labelled mask as a list of ``(x, y)`` tuples.

    ``x`` is the mean row index and ``y`` the mean column index of the pixels
    carrying a label; label 0 is treated as background and skipped. Entries are
    ordered by ascending label. ``remove_overlaps_nearest_neighbors`` looks them
    up as ``centroids[mask_id - 1]``, which is correct only when the labels are
    exactly ``1..N`` with no gaps.
    """
    rows, cols = np.nonzero(flatmasks)
    values = flatmasks[rows, cols]
    maskframe = pd.DataFrame(np.transpose(np.array([rows, cols, values]))).rename(
        columns={0: "x", 1: "y", 2: "id"}
    )
    # groupby sorts by label, which fixes the order of the returned list.
    return maskframe.groupby("id").agg({"x": "mean", "y": "mean"}).to_records(index=False).tolist()

def remove_overlaps_nearest_neighbors(original_masks, masks, centroids):
    """Flatten grown label layers into one image without overwriting existing cells.

    Parameters
    ----------
    original_masks : ndarray, shape (H, W)
        Label image before this growth step; 0 marks background.
    masks : ndarray, shape (H, W, L)
        Stack of grown label layers (0 = no mask in that layer).
    centroids : sequence of (x, y)
        ``centroids[mask_id - 1]`` is the (row, column) centroid of ``mask_id``.

    Returns
    -------
    ndarray, shape (H, W), dtype ``original_masks.dtype``
        Pixels labelled in ``original_masks`` keep their label. Former background
        pixels take the label of the single layer covering them. When several
        layers cover such a pixel, they take the label whose centroid is closest,
        with ties going to the lowest layer.
    """
    final_masks = np.max(masks, axis=2)
    background = original_masks == 0

    # Only background pixels claimed by two or more layers need arbitration. Any
    # other pixel is either restored from original_masks below or has one candidate.
    contested = background & (np.sum(masks > 0, axis=2) > 1)
    rows, cols = np.nonzero(contested)
    if rows.size:
        candidates = masks[rows, cols].astype(np.int64)  # (K, L); 0 = layer not covering
        present = candidates > 0
        centroid_table = np.asarray(centroids, dtype=float)
        # Index 0 is a harmless placeholder for absent candidates; they get inf below.
        lookup = np.where(present, candidates - 1, 0)
        d_row = centroid_table[lookup, 0] - rows[:, None]
        d_col = centroid_table[lookup, 1] - cols[:, None]
        dists = np.sqrt(d_row * d_row + d_col * d_col)
        dists[~present] = np.inf
        # argmin returns the first minimum, i.e. the lowest layer on ties.
        winner = np.argmin(dists, axis=1)
        final_masks[rows, cols] = candidates[np.arange(rows.size), winner]

    # Growth may only claim background; every originally labelled pixel is kept.
    return np.where(background, final_masks, original_masks).astype(original_masks.dtype)

def _assign_layers(masks, num_neighbors):
    """Greedily map each non-zero label to a layer so centroid kNN neighbours never share one.

    Returns ``{mask_id: layer}`` for every non-zero label in ``masks``. Layers are
    numbered contiguously from 0, because each mask takes the smallest layer its
    neighbours do not use.
    """
    rows, cols = np.nonzero(masks)
    maskframe = pd.DataFrame({"x": rows, "y": cols, "id": masks[rows, cols]})
    cent_frame = maskframe.groupby("id").agg({"x": "mean", "y": "mean"})
    ids = cent_frame.index.tolist()
    if num_neighbors < 1:
        # Nothing to keep apart (e.g. a single mask): everything goes in layer 0.
        return {mask_id: 0 for mask_id in ids}

    adjacency = kneighbors_graph(cent_frame.to_numpy(), num_neighbors).toarray() > 0
    # The kNN relation is not symmetric: B can be among A's neighbours while A is
    # not among B's. Colour against both directions so neither pair shares a layer.
    adjacency |= adjacency.T

    layer_of = {}
    for n, mask_id in enumerate(ids):
        used = {layer_of[ids[m]] for m in np.flatnonzero(adjacency[n]) if ids[m] in layer_of}
        # Smallest free layer. Scanning a sorted list and stopping at the first gap
        # is wrong when two neighbours share a layer (e.g. [0, 0, 1] would give 1).
        layer = 0
        while layer in used:
            layer += 1
        layer_of[mask_id] = layer
    return layer_of


def grow_masks(flatmasks, centroids, growth, num_neighbors = 30):
    """Grow every labelled mask by ``growth`` pixels without overwriting other cells.

    Each iteration does three things:

    1. Assigns masks to layers so that masks whose centroids are k-nearest
       neighbours (in either direction) never share a layer.
    2. Dilates every layer by one pixel with a ``disk(1)`` footprint.
    3. Flattens the layers with ``remove_overlaps_nearest_neighbors``, using the
       current labels as the "original" image.

    Parameters
    ----------
    flatmasks : ndarray, shape (H, W)
        Integer label image; 0 is background.
    centroids : list of (x, y)
        Passed to ``remove_overlaps_nearest_neighbors``, which looks centroids up
        as ``centroids[mask_id - 1]``.
    growth : int
        Number of one-pixel growth iterations.
    num_neighbors : int
        Neighbours considered when assigning layers; clipped to ``num_masks - 1``.

    Returns
    -------
    ndarray, shape (H, W)
        The grown label image. ``flatmasks`` itself is returned when there are no
        labelled masks or ``growth`` is 0.
    """
    masks = flatmasks
    # Count non-zero labels directly: subtracting 1 from the number of unique
    # values is off by one when the image contains no background.
    mask_ids = np.unique(masks)
    num_masks = int(np.count_nonzero(mask_ids))
    if num_masks == 0:
        return masks
    num_neighbors = min(num_neighbors, num_masks - 1)
    dilation_mask = disk(1)
    image_h, image_w = masks.shape

    # Overlaps are removed after each iteration so grown pixels stay attached to
    # their parent masks.
    for _ in range(growth):
        layer_of = _assign_layers(masks, num_neighbors)
        possible_layers = max(layer_of.values()) + 1
        expanded_masks = np.zeros((image_h, image_w, possible_layers), dtype = np.uint32)
        for layer in range(possible_layers):
            currids = [mask_id for mask_id, assigned in layer_of.items() if assigned == layer]
            masklocs = np.isin(masks, currids)
            expanded_masks[masklocs, layer] = masks[masklocs]
            expanded_masks[:, :, layer] = dilation(expanded_masks[:, :, layer], dilation_mask)
        masks = remove_overlaps_nearest_neighbors(masks, expanded_masks, centroids)

    return masks

In my tests, this altered version performed as expected. You should of course test it on some of your own examples, but I believe these changes fix mask growing (at least for the Standard method), and you may want to consider incorporating them into the CellSeg codebase.