kevinjohncutler / omnipose

Omnipose: a high-precision solution for morphology-independent cell segmentation
https://omnipose.readthedocs.io

3D Segmentation with Omnipose using 2D Cellpose models #98

Closed mrdandelion6 closed 1 month ago

mrdandelion6 commented 3 months ago

I am trying to segment 3D TIFF cell images with Omnipose but am running into a shape mismatch error. I am following the example code in the Omnipose documentation and have even tested with the same file it uses. Note that I am not testing 3D Omnipose models, but rather 2D Cellpose models with do_3D=True and omni=True.

This is the start of the guide I am following on the Omnipose docs. [image]

Task

I want to segment 3D TIFF files using Omnipose mask building but with Cellpose 2D models.

The Error:

When I attempt to segment 3D images using Omnipose's 2D Cellpose models (by setting do_3D=True and omni=True), I get a shape mismatch error. In particular, I am forced to configure the model with nclasses=3 even though the documentation example uses nclasses=2, because nclasses=2 produces a separate error (more on that below). I believe this is the root cause of the shape mismatch error. The error occurs when calling model.eval. Here is the last part of the error message:

ValueError: operands could not be broadcast together with remapped shapes [original->remapped]: (3,2)  and requested shape (4,2)
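
For reference, NumPy raises a message of this form when an array of per-axis pairs with 3 rows is broadcast to a target with 4 rows, for example when a specification written for a 3-axis array is applied to a 4-axis one. The sketch below only reproduces that kind of message. The names pairs_for_3_axes and broadcast_pairs are hypothetical, and it is an assumption, not something confirmed by the traceback here, that the failure inside model.eval has this cause.

```python
import numpy as np

# Hypothetical per-axis (before, after) pairs for a 3-axis array: shape (3, 2).
pairs_for_3_axes = np.zeros((3, 2), dtype=int)


def broadcast_pairs(pairs, ndim):
    """Broadcast per-axis pairs to shape (ndim, 2), one row per array axis."""
    return np.broadcast_to(pairs, (ndim, 2))


# A matching axis count works.
ok = broadcast_pairs(pairs_for_3_axes, ndim=3)

# A 4-axis target fails: 3 rows cannot be broadcast to 4 rows.
try:
    broadcast_pairs(pairs_for_3_axes, ndim=4)
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(error_message)
```

If the real failure is of this kind, the useful thing to look for in the full traceback is which array has one more (or one fewer) axis than the code handling it expects.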

Here is a link to a txt file with the full error message.

Note that my code is identical to the Omnipose documentation example for 3D segmentation with do_3D aside from two things. First, I have set gpu=True instead of gpu=False in my model configuration. Second, I have nclasses=3 instead of nclasses=2 in my model configuration. This is because when I set nclasses=2, I get a separate error:

RuntimeError: Error(s) in loading state_dict for CPnet:
    size mismatch for output.2.weight: copying a param with shape torch.Size([3, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([2, 32, 1, 1]).
    size mismatch for output.2.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([2]).

You can see the full error I get when setting nclasses=2 at this link. Note that this error only occurs when gpu=True is set. When I have gpu=False, I get a similar output but it does not produce an error:

failed to load model Error(s) in loading state_dict for CPnet:
    size mismatch for output.2.weight: copying a param with shape torch.Size([3, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([2, 32, 1, 1]).
    size mismatch for output.2.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([2]).
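
Read literally, the message above says that the saved weights have 3 output channels in the final layer, while the model built with nclasses=2 expects 2. A minimal sketch of that comparison follows, using only the shapes quoted in the message. The helper find_shape_mismatches is hypothetical and is not part of Omnipose, Cellpose, or PyTorch.

```python
def find_shape_mismatches(checkpoint_shapes, model_shapes):
    """Return {name: (checkpoint_shape, model_shape)} for parameters whose shapes differ."""
    mismatches = {}
    for name, ckpt_shape in checkpoint_shapes.items():
        model_shape = model_shapes.get(name)
        if model_shape is not None and tuple(model_shape) != tuple(ckpt_shape):
            mismatches[name] = (tuple(ckpt_shape), tuple(model_shape))
    return mismatches


# Shapes copied from the error message above.
checkpoint_shapes = {"output.2.weight": (3, 32, 1, 1), "output.2.bias": (3,)}
model_shapes = {"output.2.weight": (2, 32, 1, 1), "output.2.bias": (2,)}

mismatches = find_shape_mismatches(checkpoint_shapes, model_shapes)

# The first dimension of the final layer's weight is its number of output channels.
checkpoint_out_channels = checkpoint_shapes["output.2.weight"][0]

print(mismatches)
print(checkpoint_out_channels)
```

With PyTorch, dictionaries like these can be built from model.state_dict() and from a checkpoint loaded with torch.load(path, map_location="cpu"), assuming the checkpoint file stores a plain state dict.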

My only option therefore seems to be nclasses=3, since I want to use my GPU. Here are the lines of code that generate the shape mismatch error when I configure the model with nclasses=3.

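from cellpose_omni import models, core

# Pretrained 2D Cellpose model, as in the documentation example
model_name = "plant_cp"

# Use the GPU if one is available
use_GPU = core.use_gpu()
model = models.CellposeModel(gpu=use_GPU,
                             model_type=model_name,
                             net_avg=False,
                             diam_mean=0,
                             nclasses=3,  # the documentation example uses nclasses=2
                             dim=2,
                             nchan=2)

# imgs is the 3D image data loaded earlier in the full script (not shown here)
masks_cp, flows_cp, _ = model.eval(imgs,
                                   channels=[0,0],
                                   rescale=None,
                                   mask_threshold=0,
                                   net_avg=0,
                                   transparency=True,
                                   flow_threshold=0.,
                                   verbose=0,
                                   tile=0,
                                   compute_masks=1,
                                   do_3D=True,
                                   omni=1,
                                   flow_factor=10)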

You can see the full code for the above snippet here.

I have played around with the parameters for both model configuration and the model.eval function but have not been able to resolve the issue.

Specification

I am running the script remotely on Narval clusters with Omnipose 0.4.4. I am using CUDA 11.8 and Python 3.10.2. The OS is Gentoo Base System release 2.6, provided by the Narval clusters. Here are the GPUs I am using on the cluster:

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:01:00.0 Off |                    0 |
| N/A   30C    P0             51W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100-SXM4-40GB          Off |   00000000:81:00.0 Off |                    0 |
| N/A   30C    P0             50W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A100-SXM4-40GB          Off |   00000000:C1:00.0 Off |                    0 |
| N/A   29C    P0             49W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

Note that I am using a virtual environment instead of Conda. The compute nodes I am using do not allow Conda. Here is the result of running pip list, showing all my installed dependencies.

Thank you,

kevinjohncutler commented 3 months ago

Ouch, that spelling mistake on my part! They do crop up everywhere...

How exciting! This is the first time I have seen anyone try to use that code. And as it turns out, I made some changes to the eval loop (handling dims, channels, etc.) that messed up the do_3D branch. I just pushed changes to fix that and tested with that notebook. Works great on GPU.

kevinjohncutler commented 1 month ago

@mrdandelion6 closing now, please reopen if you have any further issues.