FZJ-INM1-BDA / celldetection

Scalable Instance Segmentation using PyTorch & PyTorch Lightning.
https://docs.celldetection.org
Apache License 2.0

Can I get the contour from (x, y) points or a bounding box? #16

Open kpbhat25 opened 4 months ago

kpbhat25 commented 4 months ago

Where can I feed the points or bounding box coordinates into the model so that I get a contour? Can you please tell me?

ericup commented 4 months ago

Great question! It's not implemented in the package yet, but I'll add it to the list of things to include in the next update round of the cpn code.

For now you could test this:

import celldetection as cd
import numpy as np
import torch
import torchvision.ops.boxes as bx
from skimage.filters import rank
from skimage.morphology import disk

def select_boxes_by_iou(box_proposals, box_targets):
    iou = bx.box_iou(box_proposals, box_targets)
    indices = torch.argmax(iou, dim=0)  # double assignments possible
    return indices

def points2point_query(points, size, radius=None):
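    point_query = np.zeros(size, dtype='uint8')
    points = np.asarray(points).astype(int)  # (x, y) pixel coordinates; float indices would fail
    point_query[points[:, 1], points[:, 0]] = 1  # index as [row=y, col=x]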
    if radius is not None:
        point_query = rank.maximum(point_query, footprint=disk(radius))
    return point_query

def boxes2box_query(boxes, size):
    return np.max(cd.data.boxes2masks(boxes, size), 0).clip(0, 1)

def points2contours(model, img, points):
    assert img.ndim == 3, 'Convert to RGB first'
    size = img.shape[:2]
    device = cd.get_device(model)
    x = cd.to_tensor(img, transpose=True, device=device, dtype=torch.float32)[None] / 255
    point_query = points2point_query(points, size, radius=4)
    q_points = cd.to_tensor(point_query > 0, device=device, dtype=torch.float32)[None, None]
    with torch.no_grad():
        y_from_points = model(x, scores_upper_bound=q_points, scores_lower_bound=q_points)
    return y_from_points['contours'][0]

def boxes2contours(model, img, boxes, points=None):
    assert img.ndim == 3, 'Convert to RGB first'
    size = img.shape[:2]
    device = cd.get_device(model)
    x = cd.to_tensor(img, transpose=True, device=device, dtype=torch.float32)[None] / 255

    # Prepare queries
    box_query = boxes2box_query(boxes, size)
    q_boxes = cd.to_tensor(box_query > 0, device=device, dtype=torch.float32)[None, None]

    # Run model
    with torch.no_grad():
        if points is None:
            # Box query: select whole box area (slower)
            y_from_boxes = model(x, scores_upper_bound=q_boxes, scores_lower_bound=q_boxes, nms=False)
        else:
            # Box query via point query (faster)
            point_query = points2point_query(points, size, radius=4)
            q_points = cd.to_tensor(point_query > 0, device=device, dtype=torch.float32)[None, None]
            y_from_boxes = model(x, scores_upper_bound=q_points, scores_lower_bound=q_points, nms=False)

    # Filter contours using provided target boxes
    box_proposals, = y_from_boxes['boxes']
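    # Match dtype and device of the proposals, otherwise box_iou fails when the model runs on GPU
    box_targets = torch.as_tensor(boxes, dtype=box_proposals.dtype, device=box_proposals.device)
    keep = select_boxes_by_iou(box_proposals, box_targets)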
    return y_from_boxes['contours'][0][keep]  # only keep matching contours
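As the comment in `select_boxes_by_iou` notes, double assignments are possible: `torch.argmax(iou, dim=0)` picks the best proposal for each target box independently, so two nearby target boxes can end up with the same contour. If that matters, one option is a one-to-one matching. The sketch below swaps in Hungarian matching via SciPy's `linear_sum_assignment`; it is not part of celldetection, and `box_iou_np` and `select_boxes_one_to_one` are hypothetical names. It assumes boxes in (x0, y0, x1, y1) format with non-zero area.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def box_iou_np(a, b):
    """Pairwise IoU for (x0, y0, x1, y1) boxes; returns an array of shape [len(a), len(b)]."""
    a = np.asarray(a, dtype=float).reshape(-1, 4)
    b = np.asarray(b, dtype=float).reshape(-1, 4)
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    lt = np.maximum(a[:, None, :2], b[None, :, :2])  # top-left corner of each intersection
    rb = np.minimum(a[:, None, 2:], b[None, :, 2:])  # bottom-right corner of each intersection
    wh = np.clip(rb - lt, 0, None)                   # zero width/height if boxes do not overlap
    inter = wh[..., 0] * wh[..., 1]
    return inter / (area_a[:, None] + area_b[None, :] - inter)


def select_boxes_one_to_one(box_proposals, box_targets, min_iou=0.0):
    """Hypothetical alternative to select_boxes_by_iou: each proposal is used at most once.

    Returns (proposal_indices, target_indices), sorted by target index.
    Pairs with IoU <= min_iou are dropped instead of being force-matched.
    """
    iou = box_iou_np(box_proposals, box_targets)
    rows, cols = linear_sum_assignment(iou, maximize=True)  # maximize total IoU
    order = np.argsort(cols)                                # report matches in target order
    rows, cols = rows[order], cols[order]
    ok = iou[rows, cols] > min_iou
    return rows[ok], cols[ok]
```

Inside `boxes2contours`, one could call this on `box_proposals.cpu().numpy()` and `boxes`, then index the contours with the returned proposal indices (converting them with `torch.as_tensor` if needed). Target boxes with no overlapping proposal are then dropped rather than matched to an arbitrary contour.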

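Use CPN (Contour Proposal Network) models for this. Points should be an Array[n, 2] of (x, y) pixel coordinates and boxes an Array[n, 4] in (x0, y0, x1, y1) format, which is what torchvision's box_iou expects.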
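The faster branch of `boxes2contours` needs click points alongside the boxes. If only boxes are available, one option (an assumption, not something suggested in the reply) is to use each box's center as its click. A minimal NumPy sketch with the hypothetical helpers `box_centers` and `clip_points`, assuming (x0, y0, x1, y1) boxes and (x, y) points:

```python
import numpy as np


def box_centers(boxes):
    """Hypothetical helper: integer (x, y) centers of (x0, y0, x1, y1) boxes, shape [n, 2]."""
    boxes = np.asarray(boxes, dtype=float).reshape(-1, 4)
    cx = (boxes[:, 0] + boxes[:, 2]) / 2
    cy = (boxes[:, 1] + boxes[:, 3]) / 2
    return np.stack([cx, cy], axis=1).round().astype(int)


def clip_points(points, size):
    """Hypothetical helper: clamp (x, y) points into an image whose shape starts with (height, width)."""
    points = np.array(points, dtype=int).reshape(-1, 2)  # copy, so the caller's array is untouched
    h, w = size[:2]
    points[:, 0] = points[:, 0].clip(0, w - 1)  # x is the column index
    points[:, 1] = points[:, 1].clip(0, h - 1)  # y is the row index
    return points
```

With these, `boxes2contours(model, img, boxes, points=clip_points(box_centers(boxes), img.shape))` would take the faster point-query branch. Box centers are only a heuristic: for curved or irregular cells the center can fall outside the object, in which case the slower whole-box query is the safer choice.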

Let me know if this helps or you have any questions!

kpbhat25 commented 4 months ago

This is helpful, thank you! I tried to replicate NuClick (https://github.com/mostafajahanifar/nuclick_torch), a click-based annotation tool, since your paper provides insight into producing better contours.

ericup commented 4 months ago

That's a really cool idea! Feel free to keep me posted, and let me know if you have any questions.