halleewong / ScribblePrompt

[ECCV 2024] ScribblePrompt: Fast and Flexible Interactive Segmentation for Any Medical Image
http://scribbleprompt.csail.mit.edu/
Apache License 2.0

Instantiating ScribblePrompt-UNet #9

Open chenyuanjiao342 opened 1 month ago

chenyuanjiao342 commented 1 month ago

Hi, when I try to instantiate ScribblePrompt-UNet and make predictions, the segmentation results I get are not accurate. Part of my code is as follows. Is there anything wrong? Thank you very much!

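import argparse
import os

import numpy as np
import pydicom
import torch
import torch.nn.functional as F
from skimage import measure
from scribbleprompt import ScribblePromptUNet

def create_scribble_tensor(positive_coords, negative_coords, H, W):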
    scribbles = torch.zeros((1, 2, H, W), dtype=torch.float32)
    for coord in positive_coords:
        x, y = coord
        if 0 <= x < W and 0 <= y < H:
            scribbles[0, 0, y, x] = 1

    for coord in negative_coords:
        x, y = coord
        if 0 <= x < W and 0 <= y < H:
            scribbles[0, 1, y, x] = 1
    return scribbles

def binary_mask_to_polygon(binary_mask, tolerance=0):
    padded_binary_mask = np.pad(binary_mask, pad_width=1, mode='constant', constant_values=0)
    contours = measure.find_contours(padded_binary_mask, 0.5)
    return contours

def parse_coords(coords_str):
    coords = list(map(int, coords_str.split(',')))
    coords_list = [coords[i:i+2] for i in range(0, len(coords), 2)]
    return coords_list

def main(args: argparse.Namespace) -> None:
    targets = [args.input_dir]

    for index, t in enumerate(targets):
        name = os.path.basename(t)
        image = pydicom.dcmread(t)
        image_w = image.Rows
        image_h = image.Columns
        image = image.pixel_array
        image = (image - image.min()) / (image.max() - image.min())
        image = np.clip(image, 0, 1)
        image = torch.tensor(image, dtype=torch.float32)
        image = image.unsqueeze(0).unsqueeze(0)
        image = image.permute(0, 1, 3, 2)
        image = F.interpolate(image, size=(128,128), mode='bilinear')

        pos_coords = parse_coords(args.pos_box)
        neg_coords = parse_coords(args.neg_box)
        positive_coords = [tuple(point) for point in pos_coords]
        negative_coords = [tuple(point) for point in neg_coords]
        scribbles = create_scribble_tensor(positive_coords, negative_coords, image_h, image_w)
        scribbles = F.interpolate(scribbles, size=(128,128), mode='bilinear') 

        sp_unet = ScribblePromptUNet()
        mask = sp_unet.predict(image, None, None, scribbles, None, None)
        mask = F.interpolate(mask, size=(image_h, image_w), mode='bilinear').squeeze()   
        mask = mask.cpu().numpy()     

        binary_mask = (mask > 0.5).astype(int)
        contours = binary_mask_to_polygon(binary_mask)

halleewong commented 1 month ago

Can you show examples of the input image and coordinates you're using and the predictions you get?

chenyuanjiao342 commented 2 weeks ago

> Can you show examples of the input image and coordinates you're using and the predictions you get?

The first image shows the input image with the scribbles, and the second image shows the predicted result.

halleewong commented 2 weeks ago

hmm I am not quite sure what's going on. Here are some ideas for troubleshooting:
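
One thing that may be worth checking first is whether the scribbles survive the resize and line up with the image. Two details in the snippet above stand out. In DICOM, `Rows` is the image height and `Columns` is the width, so `image_w` and `image_h` look swapped for non-square images. Also, bilinear `F.interpolate` without antialiasing only samples a couple of source pixels per output pixel, so single-pixel marks drawn at full resolution can shrink to 0.25 or vanish entirely when downsampled to 128x128. The sketch below draws the clicks directly on the 128x128 grid instead and runs a few basic sanity checks. It is a rough illustration under assumptions, not code from the repository: `rasterize_scribbles`, `check_inputs`, and the `radius` parameter are hypothetical names, and it assumes coordinates are given as (x, y) with x indexing columns.

```python
import numpy as np


def rasterize_scribbles(pos_xy, neg_xy, orig_hw, out_hw=(128, 128), radius=1):
    """Draw clicks directly on the resized grid (hypothetical helper).

    pos_xy / neg_xy: iterables of (x, y) pixel coordinates in the original
    image, assuming x indexes columns and y indexes rows.
    orig_hw: (height, width) of the original image.
    Returns a float32 array of shape (1, 2, out_h, out_w); channel 0 holds
    positive marks and channel 1 negative marks, as in the code above.
    """
    orig_h, orig_w = orig_hw
    out_h, out_w = out_hw
    scribbles = np.zeros((1, 2, out_h, out_w), dtype=np.float32)
    for channel, coords in enumerate((pos_xy, neg_xy)):
        for x, y in coords:
            if not (0 <= x < orig_w and 0 <= y < orig_h):
                continue  # ignore points outside the original image
            # Map the pixel centre onto the resized grid.
            cx = min(int((x + 0.5) * out_w / orig_w), out_w - 1)
            cy = min(int((y + 0.5) * out_h / orig_h), out_h - 1)
            # Paint a small square so the mark is wider than one pixel.
            y0, y1 = max(cy - radius, 0), min(cy + radius + 1, out_h)
            x0, x1 = max(cx - radius, 0), min(cx + radius + 1, out_w)
            scribbles[0, channel, y0:y1, x0:x1] = 1.0
    return scribbles


def check_inputs(image, scribbles):
    """Return a list of obvious problems; an empty list means none found."""
    problems = []
    if tuple(image.shape[-2:]) != tuple(scribbles.shape[-2:]):
        problems.append("image and scribbles have different spatial sizes")
    if image.min() < 0 or image.max() > 1:
        problems.append("image values fall outside [0, 1]")
    if scribbles[:, 0].max() == 0:
        problems.append("positive scribble channel is empty")
    return problems


# Example: a 512x512 image with one positive and one negative click.
scribbles = rasterize_scribbles([(256, 100)], [(10, 10)], orig_hw=(512, 512))
image = np.random.rand(1, 1, 128, 128).astype(np.float32)
print(check_inputs(image, scribbles))
```

The resulting array can be converted with `torch.from_numpy(scribbles)` before being passed to `predict`. Plotting the resized image with the scribble channels overlaid is a quick way to confirm they land on the intended structure and that the image has not been transposed relative to the marks.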