JoHof / lungmask

Automated lung segmentation in CT
Apache License 2.0

Possible Performance Improvement for Postrocessing #65

Closed william-silversmith closed 1 year ago

william-silversmith commented 2 years ago

Hi JoHof! I was just googling around to see how people were using my libraries and I saw that the postrocessing function could be made a lot faster. I saw many people have used this work as a template (congrats!), so hopefully many different projects can benefit from a speedup. It's possible this is not a bottleneck in your code, in which case please ignore this issue, as I don't want to waste your time.

https://github.com/JoHof/lungmask/blob/master/lungmask/utils.py#L194-L250

Here are some example improvements using the fastremap and cc3d libraries. I am not certain I understood every aspect of your algorithm, and the code hasn't been tested, so view this as just a guide.

I'm a little confused by why you need to extract the largest component after fusing neighboring components by size, but I'm probably missing something. However, if you can skip that step, this code will be blazingly fast. I added a renumbering step at the end so the labels that come out will be numbered from 1.

def postrocessing(label_image, spare=[]):
    '''some post-processing mapping small label patches to the neighbour with which they share the
        largest border. All connected components smaller than min_area will be removed
    '''

    # merge small components to neighbours
    regionmask, N = cc3d.connected_components(label_image, return_N=True)
    stats = cc3d.statistics(regionmask)
    volumes = stats['voxel_counts']
    connectivity = 26  # assumption: same neighbourhood as the connected_components default
    edges = cc3d.region_graph(regionmask, connectivity=connectivity)

    remap = {}
    for label in range(N+1):
        neighbours = set([ label ])
        for edge in edges:
            if label not in edge:
                continue
            neighbours.add(edge[0])
            neighbours.add(edge[1])

        neighbours = list(neighbours)
        biggest_label = neighbours[
            np.argmax([ volumes[lbl] for lbl in neighbours ])
        ]
        remap[label] = biggest_label

    for val in spare:
        remap[val] = val

    fastremap.remap(regionmask, remap, in_place=True)

    if regionmask.shape[0] == 1:
        # holefiller = lambda x: ndimage.morphology.binary_fill_holes(x[0])[None, :, :] # This is bad for slices that show the liver
        holefiller = lambda x: skimage.morphology.area_closing(x[0].astype(int), area_threshold=64)[None, :, :] == 1
    else:
        holefiller = fill_voids.fill

    outmask = np.zeros(regionmask.shape, dtype=np.uint8)
    for label, mask in cc3d.each(regionmask):
        mask = cc3d.largest_k(mask, k=1)
        outmask[holefiller(mask)] = label

    fastremap.renumber(outmask, in_place=True)
    return outmask

I hope this is helpful. Thanks for reading!

william-silversmith commented 2 years ago

I realized that this doesn't account for contact surfaces, instead substituting total voxel count. However, it was an easy update to cc3d to add contact surface area or contact voxel count calculation via cc3d.contacts in version 3.10.0. This is a generally useful feature for lots of people, so don't worry about it being too specific.
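To make the difference between the two criteria concrete, here is a small pure-Python sketch. The contacts dictionary is hand-written and only mimics the {(label_a, label_b): count} shape that cc3d.contacts is described as returning in this thread; the volumes table and the helper name best_neighbour are hypothetical and not part of either library.

```python
# Hypothetical data: pairwise contact counts keyed by a pair of labels,
# plus total voxel counts per label.
contacts = {(1, 2): 4, (1, 3): 9, (2, 3): 1}
volumes = {1: 50, 2: 500, 3: 120}


def best_neighbour(label, contacts):
    """Return the neighbour sharing the largest contact with `label`, or None."""
    best, best_count = None, 0
    for (a, b), count in contacts.items():
        if label not in (a, b):
            continue  # this pair does not involve `label`
        other = b if a == label else a
        if count > best_count:
            best, best_count = other, count
    return best


# Label 1 touches label 2 (4 voxels) and label 3 (9 voxels).
by_contact = best_neighbour(1, contacts)               # largest shared border
by_volume = max((2, 3), key=lambda lbl: volumes[lbl])  # largest neighbour overall
print(by_contact, by_volume)
```

In this toy case the two criteria disagree: the largest neighbour by volume is not the one with the largest shared border, which is why the choice of criterion matters for the merge step.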


def postrocessing(label_image, spare=[]):
    '''some post-processing mapping small label patches to the neighbour with which they share the
        largest border. All connected components smaller than min_area will be removed
    '''

    # merge small components to neighbours
    regionmask, N = cc3d.connected_components(label_image, return_N=True)
    connectivity = 26  # assumption: same neighbourhood as the connected_components default
    edges = cc3d.contacts(regionmask, connectivity=connectivity, surface_area=False) # contact voxel count

    remap = {}
    for label in range(N+1):
        contacts = []
        for edge in edges:
            if label not in edge:
                continue
            contacts.append((edge, edges[edge] ))

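        if len(contacts) == 0:
            remap[label] = label  # no touching neighbour: keep the label as is
            continue

        # each entry is ((label_a, label_b), count); take the pair with the largest contact
        edge, _ = sorted(contacts, key=lambda x: x[1], reverse=True)[0]
        biggest_label = edge[0] if edge[1] == label else edge[1]
        remap[label] = biggest_label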

    for val in spare:
        remap[val] = val

    fastremap.remap(regionmask, remap, in_place=True)

    if regionmask.shape[0] == 1:
        # holefiller = lambda x: ndimage.morphology.binary_fill_holes(x[0])[None, :, :] # This is bad for slices that show the liver
        holefiller = lambda x: skimage.morphology.area_closing(x[0].astype(int), area_threshold=64)[None, :, :] == 1
    else:
        holefiller = fill_voids.fill

    outmask = np.zeros(regionmask.shape, dtype=np.uint8)
    for label, mask in cc3d.each(regionmask):
        mask = cc3d.largest_k(mask, k=1)
        outmask[holefiller(mask)] = label

    fastremap.renumber(outmask, in_place=True)
    return outmask

JoHof commented 2 years ago

Hi William,

thanks for your suggestions, they are highly appreciated. Using your package would indeed make the code much better. I have been planning to rewrite this package for a while now because I am not happy with its implementation, but I simply can't find the time. Your suggestions would be a good start. However, your algorithm above does not do exactly the same thing as the current code. It doesn't guarantee that only labels in the spare set will remain in the final label image. The current implementation iteratively merges smaller areas into larger neighbouring areas, while your code maps only once. There may also be a bug: the loop for val in spare: remap[val] = val would more or less randomly ignore mappings, because the values in spare refer to the original labels while the remap dict refers to the labels assigned during cc3d, right?
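The label mismatch described above can be reproduced without any imaging library. The sketch below is a toy stand-in, not the lungmask or cc3d code: label_runs is a hypothetical helper that labels runs in a 1D list the way a connected-components pass would assign fresh ids, and the example values are made up.

```python
def label_runs(row):
    """Label each run of equal non-zero values in a 1D list with 1, 2, 3, ..."""
    out, next_id, prev = [], 0, 0
    for value in row:
        if value == 0:
            out.append(0)  # background stays 0
        else:
            if value != prev:
                next_id += 1  # a new run starts, so it gets a new component id
            out.append(next_id)
        prev = value
    return out


label_image = [3, 3, 0, 5, 5, 0, 3]    # original labels (hypothetical)
components = label_runs(label_image)   # fresh component ids

# Map each component id back to the original label it came from.
component_to_original = {
    c: o for c, o in zip(components, label_image) if c != 0
}

spare = [5]  # expressed in ORIGINAL labels

# Using spare values directly as component ids hits the wrong key (or none at all).
spare_is_a_component_id = 5 in component_to_original

# Translating spare into component ids first avoids the mismatch.
spare_components = [c for c, o in component_to_original.items() if o in spare]
print(components, component_to_original, spare_components)
```

Here the spare label 5 corresponds to component 2, so a remap keyed by component ids has to be indexed with 2, not 5.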

william-silversmith commented 2 years ago

Yes, thank you for clarifying that. I wasn't entirely sure what role spare was playing as there were some elements that were confusing to me.

~For example, spare is a list, but it is sometimes treated like a scalar. Could this be a bug? In and n != spare, n is an integer obtained from iterating over uniques, but spare is a list. According to a test in my Python terminal, that would always evaluate to True:~

>>> [1] != 1
True

Later in the algorithm it looked like spare labels were being eliminated from the mask:

outmask_mapped[outmask_mapped==spare] = 0 
...

outmask = np.zeros(outmask_mapped.shape, dtype=np.uint8)
for i in np.unique(outmask_mapped)[1:]:
    outmask[holefiller(keep_largest_connected_component(outmask_mapped == i))] = i

return outmask

I may be misreading the code here, but I think the number of loops is the same for the merging step. The second loop only looks for the neighbour with the biggest count, which corresponds to my sort step.

    for r in tqdm(regions):
            ... 
            for ix, n in enumerate(neighbours):
                if n != 0 and n != r.label and counts[ix] > maxmap and n != spare:
                    maxmap = counts[ix]
                    mapto = n
                    myarea = r.area
            regionmask[regionmask == r.label] = mapto
            ...
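As a quick check of the claim that the running-maximum loop and a sort-then-take-first step select the same neighbour, here is a minimal sketch on made-up data. The neighbours and counts lists are hypothetical, and the spare check from the quoted loop is left out to keep the comparison focused.

```python
neighbours = [0, 7, 3, 9]    # labels found around the region (0 is background)
counts = [40, 5, 12, 12]     # how many border voxels each label contributes
own_label = 7                # the region currently being merged

# Running-maximum form, following the quoted loop.
mapto, maxmap = own_label, 0
for ix, n in enumerate(neighbours):
    if n != 0 and n != own_label and counts[ix] > maxmap:
        maxmap = counts[ix]
        mapto = n

# Sort form: filter the same way, sort by count descending, take the first.
candidates = [
    (n, c) for n, c in zip(neighbours, counts) if n != 0 and n != own_label
]
best = sorted(candidates, key=lambda x: x[1], reverse=True)[0][0]

print(mapto, best)
```

Because Python's sort is stable even with reverse=True, ties are resolved in favour of the first candidate in both forms, so the two agree on this input.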

I've not run the code, so I may be far off the mark. However, adjusting my code to your above description of the algorithm, would this be closer to what you are thinking?

def postrocessing(label_image, spare=[]):
    '''some post-processing mapping small label patches to the neighbour with which they share the
        largest border. All connected components smaller than min_area will be removed
    '''

    # merge small components to neighbours
    regionmask, N = cc3d.connected_components(label_image, return_N=True)

    component_map = fastremap.component_map(label_image, regionmask)
    component_map = { k:v for k,v in component_map.items() if k in spare }

    connectivity = 26  # assumption: same neighbourhood as the connected_components default
    while True:
        edges = cc3d.contacts(regionmask, connectivity=connectivity, surface_area=False) # contact voxel count

        if len(edges) == 0:
            break

        remap = {}
        for label in range(N+1):
            contacts = []
            for edge in edges:
                if label not in edge:
                    continue
                contacts.append((edge, edges[edge] ))

            if len(contacts) == 0:
                continue

            # each entry is ((label_a, label_b), count); take the pair with the largest contact
            edge, _ = sorted(contacts, key=lambda x: x[1], reverse=True)[0]
            biggest_label = edge[0] if edge[1] == label else edge[1]
            remap[label] = biggest_label

    if regionmask.shape[0] == 1:
        # holefiller = lambda x: ndimage.morphology.binary_fill_holes(x[0])[None, :, :] # This is bad for slices that show the liver
        holefiller = lambda x: skimage.morphology.area_closing(x[0].astype(int), area_threshold=64)[None, :, :] == 1
    else:
        holefiller = fill_voids.fill

    outmask = np.zeros(regionmask.shape, dtype=np.uint8)
    for label, mask in cc3d.each(regionmask):
        mask = cc3d.largest_k(mask, k=1)
        outmask[holefiller(mask)] = label

    fastremap.mask_except(outmask, list(component_map.values()), in_place=True)
    fastremap.renumber(outmask, in_place=True)
    return outmask

william-silversmith commented 2 years ago

Actually, strike that first comment. It looks like n != spare does work if the integer is a NumPy integer. I think that's deprecated behavior though. The updated way to write that would be np.any(n != spare), because otherwise it's not clear to the casual reader what the truthiness of if np.array([ False ]): is (it's False in my test).
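The behaviour discussed here can be checked in a few lines. This is only a sketch with made-up values, assuming NumPy is installed. It also contrasts np.any with a plain membership test, since the two answer different questions once spare holds more than one label.

```python
import numpy as np

n = np.int64(1)

# A NumPy integer compared with a list is evaluated elementwise,
# so the result is a boolean array rather than a single bool.
single = n != [1]
print(type(single).__name__, single.tolist())

# With more than one spare label, the bare comparison cannot be used in an `if`.
try:
    bool(n != [1, 2])
    ambiguous = False
except ValueError:
    ambiguous = True

# np.any asks "does n differ from at least one spare label?",
# while a membership test asks "is n absent from spare?".
differs_from_some = bool(np.any(n != [1, 2]))
not_in_spare = n not in [1, 2]
print(ambiguous, differs_from_some, not_in_spare)
```

For a single-element spare list the two formulations coincide. With several spare labels, n not in spare is probably closer to the intent of excluding spare labels, though that reading of the original loop is an assumption.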