TIO-IKIM / CellViT

CellViT: Vision Transformers for Precise Cell Segmentation and Classification
https://doi.org/10.1016/j.media.2024.103143

Predicting on 40x slides using 20x model #34

Open bgginley opened 9 months ago

bgginley commented 9 months ago

Describe the bug
Not sure if this is a bug or if I am doing something incorrectly. When I attempt to predict on a 40x magnification slide using the 20x preprocessing and the 20x model, the resulting predictions are not tiled correctly, do not appear at the correct resolution, and are slightly shifted. The same parameters work fine for a native 20x slide. This occurs for all three of my slide types (.tiff, .ndpi, .svs).

To Reproduce
Steps to reproduce the behavior. For preprocessing I used:

    python3 ./preprocessing/patch_extraction/main_extraction.py --config ./example/preprocessing_example.yaml

with the following config (I modified wsi_paths to accept a pickled pandas dataframe of WSI paths, because my WSIs have multiple extensions and live in separate folders):

    wsi_paths: /domino/datasets/local/Histo-cloud/BrandonNuclearSegmentation/NuCLS_model/CellViT/CellViT/Bladder_IO_combo_set_notiff.pkl
    output_path: ./example/output/preprocessing
    patch_size: 1024
    patch_overlap: 6.25
    target_mag: 20
    processes: 96
    overwrite: False
    normalize_stains: True
    min_intersection_ratio: 0.05
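Since I patched wsi_paths to take a pickled pandas object, here is a minimal sketch of how I build that pickle (the paths below are placeholders, not my real data):

    import pandas as pd

    # Collect WSIs with mixed extensions from separate folders into a single
    # pickled series of paths for the patched wsi_paths loading logic.
    wsi_paths = pd.Series([
        "/data/cohort_a/slide_001.svs",
        "/data/cohort_b/slide_002.ndpi",
        "/data/cohort_c/slide_003.tiff",
    ])
    wsi_paths.to_pickle("Bladder_IO_combo_set_notiff.pkl")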

For the actual predictions I used:

    python3 ./cell_segmentation/inference/cell_detection.py \
        --batch_size 5 \
        --magnification 20.0 \
        --model ./models/pretrained/CellViT-SAM-H-x20.pth \
        --gpu 3 \
        --geojson \
        process_dataset \
        --wsi_paths /domino/datasets/local/Histo-cloud/ \
        --patch_dataset_path /domino/datasets/local/Histo-cloud/BrandonNuclearSegmentation/NuCLS_model/CellViT/CellViT/example/output/preprocessing/ \
        --filelist /domino/datasets/local/Histo-cloud/BrandonNuclearSegmentation/NuCLS_model/CellViT/CellViT/cellvit_bladderIO_predictionlist_reduced4_p4-svs.csv \
        --wsi_ext svs

Not sure if there are any output files that would be helpful to post, but if so, let me know and I can add them.

Expected behavior
I expect the predictions to be tiled and stitched normally and to appear at the correct resolution.

Screenshots
This is what the predictions look like when using the above parameters on a slide with a native 20x base magnification: [screenshot (zoomed)] [screenshot]

And this is what the predictions look like when running the same parameters on a slide with a native 40x base magnification: [screenshot (zoomed)] [screenshot]

As you can see from the second set of images, the prediction tiles don't overlap correctly, the predicted nucleus boundaries don't line up with the true nuclei (as if they were drawn at the wrong resolution), and some tiles appear shifted to the right or downward (in the third image, see the top-right tissue section as well as the bottom of the tissue).

What have I done wrong here?

P.S. Beautiful algorithm, really nice work guys. The predictions do look pretty spectacular for the slides I can get working.

FabianHoerst commented 9 months ago

I'll need some more time to check this. To be honest, I would not recommend using the x20 models and x20 inference, as the performance is inferior and I have not put that much effort into the x20 prediction algorithm. Does this occur with arbitrary x40 images? If so, I could investigate on my own samples.

bgginley commented 9 months ago

Thanks for the quick reply, Fabian. I believe it does happen with arbitrary 40x images; for example, I tested the TCGA sample you use on the main page: [screenshot]

I did see in your results that the performance is inferior, but I eventually need to apply this model to a dataset containing a mixture of 40x and 20x images, so I'll need to apply the 20x model either way, and I wanted to test its performance on a 20x downsample of a 40x image. It's also faster to run at 20x, which is a pretty important criterion for the project I'm working on.

Given the above, for now I will just call the 40x model for 40x slides and the 20x model for 20x slides. It would still be interesting to see the performance of the 20x model on a 40x image, given the tradeoff with runtime. Please let me know if you find a solution; it looks like a scaling issue to me, which I suspect may not be too difficult to fix. I will also try to take a look, but I'm under a bit of a time crunch for this project, so I thought it would be more expedient to ask you first.

Thanks for the help and great work! -Brandon

FabianHoerst commented 9 months ago

I will investigate at the beginning of next year; currently I am also running short on time. I'll check whether this is related to the preprocessing or the inference pipeline. Do the patches look like they are extracted correctly?

bgginley commented 9 months ago

Yes, the patches are extracted correctly. I dug in a little and believe the culprit is these lines (342-360 in cell_detection.py):

    wsi_scaling_factor = wsi.metadata["downsampling"]
    patch_size = wsi.metadata["patch_size"]
    # The scaling factor is applied to the patch offset here ...
    x_global = int(
        patch_metadata["row"] * patch_size * wsi_scaling_factor
        - (patch_metadata["row"] + 0.5) * overlap
    )
    y_global = int(
        patch_metadata["col"] * patch_size * wsi_scaling_factor
        - (patch_metadata["col"] + 0.5) * overlap
    )

    # extract cell information
    for cell in patch_instance_types.values():
        if cell["type"] == nuclei_types["Background"]:
            continue
        offset_global = np.array([x_global, y_global])
        # ... but the patch-local centroid/contour/bbox are never rescaled,
        # so coordinates at the processed magnification are added to an
        # offset at the slide's base magnification.
        centroid_global = cell["centroid"] + np.flip(offset_global)
        contour_global = cell["contour"] + np.flip(offset_global)
        bbox_global = cell["bbox"] + offset_global

The solution I found was to apply the upsampling multiplication to the bbox, centroid, and contour rather than only to the global offset:

    wsi_scaling_factor = wsi.metadata["downsampling"]
    patch_size = wsi.metadata["patch_size"]
    # Compute the patch offset at the processed magnification (e.g. 20x),
    # without the scaling factor.
    x_global = int(
        patch_metadata["row"] * patch_size
        - (patch_metadata["row"] + 0.5) * overlap
    )
    y_global = int(
        patch_metadata["col"] * patch_size
        - (patch_metadata["col"] + 0.5) * overlap
    )

    # extract cell information
    for cell in patch_instance_types.values():
        if cell["type"] == nuclei_types["Background"]:
            continue
        offset_global = np.array([x_global, y_global])
        # Rescale the final global coordinates to the slide's base
        # magnification in a single step.
        centroid_global = (cell["centroid"] + np.flip(offset_global)) * wsi_scaling_factor
        contour_global = (cell["contour"] + np.flip(offset_global)) * wsi_scaling_factor
        bbox_global = (cell["bbox"] + offset_global) * wsi_scaling_factor

I believe this should not impact 40x operations, as I assume wsi.metadata["downsampling"] returns 1 in those cases, but I have not exhaustively tested it across a range of images.
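For concreteness, here is a small numeric sketch (with made-up values) of how the two versions diverge for a 40x slide processed at 20x:

    import numpy as np

    # Made-up values: wsi_scaling_factor = 2 (40x slide processed at 20x),
    # patch_size = 1024, overlap = 64 px (6.25% of 1024), the patch at
    # row=1, col=1, and a cell with patch-local centroid (y, x) = (100, 200).
    scale, patch_size, overlap = 2, 1024, 64
    row, col = 1, 1
    centroid_local = np.array([100, 200])

    # Current code: the scale is baked into the patch offset only, so the
    # offset is in 40x pixels while the local centroid stays in 20x pixels.
    x_buggy = int(row * patch_size * scale - (row + 0.5) * overlap)
    y_buggy = int(col * patch_size * scale - (col + 0.5) * overlap)
    centroid_buggy = centroid_local + np.flip(np.array([x_buggy, y_buggy]))

    # Proposed fix: stay in 20x pixels throughout, then rescale once.
    x_fixed = int(row * patch_size - (row + 0.5) * overlap)
    y_fixed = int(col * patch_size - (col + 0.5) * overlap)
    centroid_fixed = (centroid_local + np.flip(np.array([x_fixed, y_fixed]))) * scale

    print(centroid_buggy)  # [2052 2152] -- mixed scales
    print(centroid_fixed)  # [2056 2256] -- consistent 40x coordinates

The discrepancy grows with the cell's position inside the patch and with the overlap term, which matches the tile shifts in the screenshots above; with a scaling factor of 1, the two versions agree.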

FabianHoerst commented 8 months ago

Hello bgginley, I finally found the time to look into your solution. I will test it on some images before merging your code changes. Thank you very much for your contribution!