TIO-IKIM / CellViT

CellViT: Vision Transformers for Precise Cell Segmentation and Classification
https://doi.org/10.1016/j.media.2024.103143

How to get model prediction for a single pannuke image? #37

Closed by neoyinzhanghan 4 months ago

neoyinzhanghan commented 5 months ago

Hi!

I've been trying to use your model checkpoint to get the instance segmentation result for a single PanNuke JPG image. I cannot work out what exactly I need to do to perform this simple task, even though the package directly supports much more complex tasks such as instance segmentation on an entire WSI.

Unable to find a function/method in the package that does this directly, I attempted a hacky approach and wrote the following function:

from PIL import Image
from torchvision import transforms as T
from cell_segmentation.inference.cell_detection import CellSegmentationInference

def run_one_image(inf_class, image_path):
    # Load the image
    image = Image.open(image_path).convert("RGB")

    # Define the transformation
    transform = T.Compose(
        [
            T.Resize((224, 224)),  # Resize the image to 224x224
            T.ToTensor(),  # Convert the image to a tensor
        ]
    )

    # Apply the transformation to the image
    batch = transform(image).unsqueeze(0)  # Unsqueeze to add the batch dimension
    # batch now has shape (1, 3, 224, 224)

    inf_class.logger.info("Loading inference transformations")

    transform_settings = inf_class.run_conf["transformations"]
    if "normalize" in transform_settings:
        mean = transform_settings["normalize"].get("mean", (0.5, 0.5, 0.5))
        std = transform_settings["normalize"].get("std", (0.5, 0.5, 0.5))
    else:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
    inf_class.inference_transforms = T.Compose(
        [T.ToTensor(), T.Normalize(mean=mean, std=std)]
    )

    # patches = batch[0].to(inf_class.device)
    batch = batch.to(inf_class.device)

    # print the dimensions of the patches
    # print("Patches shape: {}".format(patches.shape))
    # print("Batch shape: {}".format(batch.shape))
    inf_class.logger.info("Patches shape: {}".format(batch.shape))

    predictions = inf_class.model.forward(batch, retrieve_tokens=True)

    instance_types, tokens = inf_class.get_cell_predictions_with_tokens(
        predictions, magnification=40
    )

    return predictions

if __name__ == "__main__":
    image_path = "/media/hdd1/pannuke/images/train/image_0035.jpg"

    inf_class = CellSegmentationInference(
        model_path="/home/alpaca/Documents/neo/CellViT/CellViT-256-x40.pth", gpu=0
    )

    prediction = run_one_image(inf_class, image_path)

However, neither the predictions dictionary nor instance_types nor tokens seems to be the desired output. All the tensors in predictions have the same shapes for every image I tried (I would expect different shapes for different images if things were running correctly, because the number of objects of interest would differ):

tissue_types torch.Size([1, 19])
nuclei_binary_map torch.Size([1, 2, 224, 224])
hv_map torch.Size([1, 2, 224, 224])
nuclei_type_map torch.Size([1, 6, 224, 224])
tokens torch.Size([1, 384, 14, 14])
<class 'list'>

The instance_types object is a list of a single element which appears to be an empty dictionary.

Would really appreciate your help on pointing out what is wrong in my implementation, or a pointer on where to look if I just want to run the model to get the segmentation on a single image. Thank you!

neoyinzhanghan commented 5 months ago

Update: I've tried a different approach, which I believe is looking in the right place. It is still very hacky, and unfortunately the number of instances detected by this script always seems to be 0.

import os
import shutil
import matplotlib.pyplot as plt
import random
import albumentations as A
from PIL import Image
from torchvision import transforms as T
from cell_segmentation.inference.inference_cellvit_experiment_pannuke import InferenceCellViT, InferenceCellViTParser

# Function to generate a random color
def random_color():
    return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))

def run_one_image(inf_class, image_path):
    # Load the image
    image = Image.open(image_path).convert("RGB")
    transform_settings = inf_class.run_conf["transformations"]

    if "normalize" in transform_settings:
        mean = transform_settings["normalize"].get("mean", (0.5, 0.5, 0.5))
        std = transform_settings["normalize"].get("std", (0.5, 0.5, 0.5))
    else:
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)

    transforms = T.Compose([T.ToTensor(), T.Normalize(mean=mean, std=std)])

    # Apply the transformation to the image
    batch = transforms(image).unsqueeze(0)  # Unsqueeze to add the batch dimension

    inf_class.logger.info("Loading inference transformations")

    # patches = batch[0].to(inf_class.device)
    batch = batch.to(inf_class.device)

    # print the dimensions of the patches
    # print("Patches shape: {}".format(patches.shape))
    # print("Batch shape: {}".format(batch.shape))
    inf_class.logger.info("Patches shape: {}".format(batch.shape))

    model = inf_class.get_model(model_type='CellViT256')

    # move the model to the GPU
    model.to(inf_class.device)

    predictions = model.forward(batch, retrieve_tokens=True)

    predictions = inf_class.unpack_predictions(predictions=predictions, model=model)

    print(predictions.instance_map[0].shape)

    # Each instance has its own integer, starting from 1. Shape: (H, W)
    # save all the instances as a binary mask image in run_dir/results as ins_1, ins_2 ,...
    num_instances = int(predictions.instance_map[0].max())

    # print the number of instances
    print("Number of instances: {}".format(num_instances))

    # Create the results directory if it does not exist
    results_dir = os.path.join(inf_class.run_dir, "results")

    if not os.path.exists(results_dir):
        os.makedirs(results_dir)

    # Save the instance map as a binary mask
    for i in range(1, num_instances + 1):
        instance_mask = predictions.instance_map[0] == i
        instance_mask = instance_mask.cpu().numpy().astype("uint8") * 255
        instance_mask = Image.fromarray(instance_mask)
        instance_mask.save(os.path.join(results_dir, f"ins_{i}.png"))

    # copy the input image to the results directory
    shutil.copy(image_path, os.path.join(results_dir, "image.png"))

    # Ensure num_instances is defined; for example, it might be the max label in your instance_map
    num_instances = int(predictions.instance_map.max())

    # Load the original image
    original_image = Image.open(image_path).convert("RGB")

    # Overlay each instance mask on the original image with a random color
    for i in range(1, num_instances + 1):
        # Load the binary mask
        instance_mask = predictions.instance_map[0] == i
        instance_mask = instance_mask.cpu().numpy().astype("uint8") * 255
        instance_mask = Image.fromarray(instance_mask).convert("L")  # Ensure it's greyscale

        # Generate a random color for the instance
        color = random_color()

        # Create an RGB image of the same size as the original image but filled with the random color
        colored_mask = Image.new("RGB", original_image.size, color=color)

        # Apply the binary mask as an alpha mask to overlay the colored mask onto the original image
        original_image.paste(colored_mask, (0,0), instance_mask)

    # Save or display the modified original image
    original_image.save(os.path.join(results_dir, "image_with_masks.png"))
    # or use original_image.show() to display the image

    # Return the unpacked predictions so the caller's assignment is not None
    return predictions


if __name__ == "__main__":

    image_path = "/media/hdd1/pannuke/images/train/image_0080.jpg"

    print("Running the inference on", image_path, "\n")

    configuration_parser = InferenceCellViTParser()
    configuration = configuration_parser.parse_arguments()
    print(configuration)
    inf_class = InferenceCellViT(
        run_dir=configuration["run_dir"],
        checkpoint_name=configuration["checkpoint_name"],
        gpu=0,
        magnification=40,
    )

    prediction = run_one_image(inf_class, image_path)

This is the print in the console:

(cellvit_env) alpaca@path-lambda1:~/Documents/neo/CellViT$ python run.py --run_dir ~/Documents/neo/CellViT/run_dir --checkpoint_name CellViT-256-x40.pth
2024-02-08 13:45:31.706321: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-02-08 13:45:32.281667: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Running the inference on /media/hdd1/pannuke/images/train/image_0040.jpg 

{'run_dir': '/home/alpaca/Documents/neo/CellViT/run_dir', 'checkpoint_name': 'CellViT-256-x40.pth', 'gpu': 5, 'magnification': 40, 'plots': False}
Loaded run: /home/alpaca/Documents/neo/CellViT/run_dir
Loading inference transformations
Patches shape: torch.Size([1, 3, 256, 256])
torch.Size([256, 256])
Number of instances: 0

My config.yaml file is as follows; I took it from one of the PanNuke config files and modified a few lines to match my own data and workstation:

logging:
  mode: online
  project: Cell-Segmentation
  notes: CellViT-SAM-B
  log_comment: CellViT-SAM-B-Fold-1
  tags:
  - Fold-1
  - SAM-B
  level: Debug
  group: CellViT-SAM-B
random_seed: 19
gpu: 0
data:
  dataset: PanNuke
  dataset_path: /home/alpaca/Documents/neo/CellViT/configs/datasets/PanNuke/
  train_folds:
  - fold0
  val_folds:
  - fold1
  test_folds:
  - fold2
  num_nuclei_classes: 6
  num_tissue_classes: 19
model:
  backbone: SAM-B
  pretrained_encoder: ./models/pretrained/SAM/sam_vit_b.pth
  shared_skip_connections: true
loss:
  nuclei_binary_map:
    focaltverskyloss:
      loss_fn: FocalTverskyLoss
      weight: 1
    dice:
      loss_fn: dice_loss
      weight: 1
  hv_map:
    mse:
      loss_fn: mse_loss_maps
      weight: 2.5
    msge:
      loss_fn: msge_loss_maps
      weight: 8
  nuclei_type_map:
    bce:
      loss_fn: xentropy_loss
      weight: 0.5
    dice:
      loss_fn: dice_loss
      weight: 0.2
    mcfocaltverskyloss:
      loss_fn: MCFocalTverskyLoss
      weight: 0.5
      args:
        num_classes: 6
  tissue_types:
    ce:
      loss_fn: CrossEntropyLoss
      weight: 0.1
training:
  drop_rate: 0
  attn_drop_rate: 0.1
  drop_path_rate: 0.1
  batch_size: 16
  epochs: 130
  optimizer: AdamW
  early_stopping_patience: 130
  scheduler:
    scheduler_type: exponential
    hyperparameters:
      gamma: 0.85
  optimizer_hyperparameter:
    betas:
    - 0.85
    - 0.95
    lr: 0.0003
    weight_decay: 0.0001
  unfreeze_epoch: 25
  sampling_gamma: 0.85
  sampling_strategy: cell+tissue
  mixed_precision: true
transformations:
  randomrotate90:
    p: 0.5
  horizontalflip:
    p: 0.5
  verticalflip:
    p: 0.5
  downscale:
    p: 0.15
    scale: 0.5
  blur:
    p: 0.2
    blur_limit: 10
  gaussnoise:
    p: 0.25
    var_limit: 50
  colorjitter:
    p: 0.2
    scale_setting: 0.25
    scale_color: 0.1
  superpixels:
    p: 0.1
  zoomblur:
    p: 0.1
  randomsizedcrop:
    p: 0.1
  elastictransform:
    p: 0.2
  normalize:
    mean:
    - 0.5
    - 0.5
    - 0.5
    std:
    - 0.5
    - 0.5
    - 0.5
eval_checkpoint: latest_checkpoint.pth
run_sweep: false
agent: null
dataset_config:
  tissue_types:
    Adrenal_gland: 0
    Bile-duct: 1
    Bladder: 2
    Breast: 3
    Cervix: 4
    Colon: 5
    Esophagus: 6
    HeadNeck: 7
    Kidney: 8
    Liver: 9
    Lung: 10
    Ovarian: 11
    Pancreatic: 12
    Prostate: 13
    Skin: 14
    Stomach: 15
    Testis: 16
    Thyroid: 17
    Uterus: 18
  nuclei_types:
    Background: 0
    Neoplastic: 1
    Inflammatory: 2
    Connective: 3
    Dead: 4
    Epithelial: 5

Again, would really appreciate your help on pointing out what is wrong in my implementation, or a pointer on where to look if I just want to run the model to get the segmentation on a single image. Thank you!

FabianHoerst commented 5 months ago

In general, your code looks like it should work. The cells are returned in instance_types, which should be a dict containing all cells. However, as this repo was not intended for single-image prediction, I would point you to the following code snippet from the inference script, which you can use to check each part of your pipeline: https://github.com/TIO-IKIM/CellViT/blob/f039eb448b32e04ee0c55e2189c782223e1dc599/cell_segmentation/inference/inference_cellvit_experiment_pannuke.py#L654
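On the question of identical tensor shapes raised earlier in this thread: the maps in predictions are dense per-pixel outputs, so their shapes follow the input size rather than the number of cells. The per-image cell count only appears once those maps are post-processed into an instance map or into instance_types. The sketch below illustrates that idea with a plain 4-connected component labelling on toy binary masks. It is a simplified stand-in and not CellViT's actual post-processing, and label_components is a hypothetical helper name introduced only for this example.

```python
from collections import deque


def label_components(mask):
    """Label 4-connected foreground regions of a 2D 0/1 mask with 1..N.

    Hypothetical helper for illustration only; not part of CellViT.
    """
    height, width = len(mask), len(mask[0])
    labels = [[0] * width for _ in range(height)]
    count = 0
    for r in range(height):
        for c in range(width):
            if mask[r][c] and labels[r][c] == 0:
                # Found an unlabelled foreground pixel: start a new instance
                count += 1
                labels[r][c] = count
                queue = deque([(r, c)])
                while queue:
                    y, x = queue.popleft()
                    for ny, nx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)):
                        inside = 0 <= ny < height and 0 <= nx < width
                        if inside and mask[ny][nx] and labels[ny][nx] == 0:
                            labels[ny][nx] = count
                            queue.append((ny, nx))
    return labels, count


# Two toy masks with the same shape but a different number of blobs
mask_a = [
    [1, 1, 0, 0],
    [1, 1, 0, 0],
    [0, 0, 0, 1],
    [0, 0, 0, 1],
]
mask_b = [
    [1, 0, 1, 0],
    [0, 0, 0, 0],
    [1, 0, 1, 0],
    [0, 0, 0, 0],
]

labels_a, count_a = label_components(mask_a)
labels_b, count_b = label_components(mask_b)
print(count_a, count_b)  # prints "2 4"
```

Both label maps have the same 4x4 shape, yet the largest label differs between them. In the same way, a count of 0 means the binary map contained no foreground after post-processing, not that the output shape was wrong.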

Note that we do not rescale the images to 224 px; we use 256 px, as can be seen here: https://github.com/TIO-IKIM/CellViT/blob/f039eb448b32e04ee0c55e2189c782223e1dc599/cell_segmentation/inference/inference_cellvit_experiment_pannuke.py#L281
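To make the preprocessing point concrete, here is a minimal NumPy sketch of what a 256 px input with the normalization from the config above (mean and std of 0.5 per channel) would look like as a batch. It mirrors the arithmetic of ToTensor followed by Normalize, assuming the image is already a 256x256 RGB array. The function name to_model_input and the size check are assumptions made for illustration and are not part of the repository.

```python
import numpy as np


def to_model_input(image_hwc, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), size=256):
    """Turn an HxWx3 uint8 RGB array into a normalized (1, 3, H, W) float32 batch.

    Hypothetical helper for illustration only; not part of CellViT.
    """
    if image_hwc.shape != (size, size, 3):
        raise ValueError(
            f"expected a ({size}, {size}, 3) image, got {image_hwc.shape}"
        )
    # Scale to [0, 1], which is what ToTensor does for uint8 images
    scaled = image_hwc.astype(np.float32) / 255.0
    # Per-channel normalization, the same arithmetic as Normalize(mean, std)
    mean_arr = np.asarray(mean, dtype=np.float32)
    std_arr = np.asarray(std, dtype=np.float32)
    normalized = (scaled - mean_arr) / std_arr
    # HWC -> CHW, then add the batch dimension
    return normalized.transpose(2, 0, 1)[np.newaxis, ...]


# Random stand-in for a 256x256 PanNuke patch
dummy = np.random.default_rng(0).integers(0, 256, size=(256, 256, 3), dtype=np.uint8)
batch = to_model_input(dummy)
print(batch.shape, batch.dtype)  # (1, 3, 256, 256) float32
```

With mean and std of 0.5 the values land in [-1, 1]. Note that the first script in this thread resized to 224 px and passed the ToTensor output to the model without applying the normalization it had just defined.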

Finally, I cannot verify whether you load the model correctly, but you can check this by comparing your model loading to the following section: https://github.com/TIO-IKIM/CellViT/blob/f039eb448b32e04ee0c55e2189c782223e1dc599/cell_segmentation/inference/inference_cellvit_experiment_pannuke.py#L150