kevinjohncutler / omnipose

Omnipose: a high-precision solution for morphology-independent cell segmentation
https://omnipose.readthedocs.io

Loss not converging #54

Closed (natelharrison closed this issue 1 year ago)

natelharrison commented 1 year ago

Hi, I am having an issue where the training loss is not converging and bounces around 10. I have successfully trained a cellpose model on the same original data, but for cellpose I tiled the images into 1x128x128 tiles and trained on those. For omnipose I cropped the images into 224x515x512 crops, resulting in about 18 image/mask pairs that look like they line up fine. I tried a variety of learning rates, such as 0.2, 0.1, and 0.01, and nothing seems to work, so I'm not sure if I am missing something.
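The preprocessing code isn't shown in the issue. As a hypothetical illustration of the two strategies described (flat 1x128x128 tiles for cellpose versus larger 3D crops for omnipose), the NumPy sketch below cuts a (z, y, x) stack into non-overlapping crops. `crop_volume` is a helper name invented here, not part of cellpose or omnipose, and the sketch assumes each image and its mask are arrays of identical shape.

```python
import numpy as np

def crop_volume(vol, crop_shape):
    """Split a (z, y, x) array into non-overlapping crops of crop_shape.

    Any remainder along an axis that does not fill a whole crop is dropped.
    """
    cz, cy, cx = crop_shape
    nz, ny, nx = (s // c for s, c in zip(vol.shape, crop_shape))
    crops = []
    for i in range(nz):
        for j in range(ny):
            for k in range(nx):
                crops.append(vol[i * cz:(i + 1) * cz,
                                 j * cy:(j + 1) * cy,
                                 k * cx:(k + 1) * cx])
    return crops

# Cropping image and mask with the same shape keeps each pair aligned.
image = np.random.rand(8, 256, 256)              # toy stand-in for a real stack
mask = np.zeros(image.shape, dtype=np.uint16)    # toy stand-in for its label mask
pairs = list(zip(crop_volume(image, (8, 128, 128)),
                 crop_volume(mask, (8, 128, 128))))
print(len(pairs))  # 4
```

Passing `(1, 128, 128)` as the crop shape gives 2D-style tiles like those used for cellpose, while a larger 3D crop shape mirrors the omnipose setup. Dropping the remainder is a simplification; padding or overlapping crops would keep the edge voxels.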

Here's an example of my training command and its output:

omnipose --use_gpu --train --dir /clusterfs/fiona/segmentation_curation/training_data/combined_dataset/3D_half_dims/ --mask_filter _masks --n_epochs 4000 --pretrained_model None --learning_rate 0.002 --save_every 50 --save_each --verbose --all_channels --dim 3 --RAdam --batch_size 1024 --diameter 0 --nclasses 3 --tyx 80,128,128

[image: training output]
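A quick bit of arithmetic on the command above: if `--tyx 80,128,128` sets the size of the patch sampled from each crop during training (our reading of the flag; the thread does not confirm it), then each training sample covers only a small part of a 224x515x512 crop. Seeing a different region at each step is one plausible contributor to a noisy loss, offered here as a hypothesis rather than a diagnosis.

```python
from math import prod

# Shapes (z, y, x) as stated in this issue. Treating --tyx as the size of the
# training patch is an assumption, not something confirmed in the thread.
crop_shape = (224, 515, 512)
patch_shape = (80, 128, 128)

# Fraction of one crop's voxels that fall inside a single patch.
fraction = prod(patch_shape) / prod(crop_shape)
print(f"each patch covers {fraction:.1%} of a crop")
```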

Also, on an unrelated note: would I be able to train a model that works with cellpose but use omnipose to leverage its distributed training?

natelharrison commented 1 year ago

After letting it run for more epochs, the loss does converge, but it fluctuates much more than I've seen with cellpose. I'm not sure whether this is expected behavior or whether I'm doing something wrong in preprocessing my data or in my command arguments.
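One way to tell a converging-but-noisy loss from a stalled one is to smooth the per-epoch values and look at the trend. The sketch below uses an exponential moving average; `alpha` is an arbitrary smoothing choice, and the loss series is synthetic (a slow linear decline hidden under a large alternating term) because the actual values from this run only appear in a screenshot. Extracting real values from the training log is not shown, since the log format is not given here.

```python
def ema(values, alpha=0.05):
    """Exponential moving average: s_t = alpha * x_t + (1 - alpha) * s_(t-1)."""
    smoothed, s = [], None
    for x in values:
        # Seed with the first value, then blend each new value into the running average.
        s = x if s is None else alpha * x + (1 - alpha) * s
        smoothed.append(s)
    return smoothed

# Synthetic losses: a downward trend hidden under a +/-3 oscillation.
losses = [10 - 0.001 * t + (3 if t % 2 else -3) for t in range(4000)]
smoothed = ema(losses)

# Consecutive raw values jump by about 6; the smoothed curve shows the decline.
print(round(smoothed[999], 1), round(smoothed[-1], 1))
```

A smaller `alpha` suppresses more noise but makes the curve lag further behind the true trend.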

[image: loss curve after training for more epochs]

kevinjohncutler commented 1 year ago

@natelharrison Sorry for the delay. I'll follow up for more details via email, but here are a couple of things others might find useful:

  1. I have also noticed that omnipose can oscillate a bit more than cellpose. This is down to the loss functions (especially with torchvf), and I recall it being more apparent with 3D data. Cellpose also seems to converge in fewer epochs (albeit generally with lower accuracy), so I am curious whether your loss continued to decrease beyond 4000 epochs, as I usually see in 2D.
  2. You can indeed train a cellpose model with omnipose, but you might need to manually set omni=False in a few places so that it trains cellpose's center-seeking flows. Omnipose models trained with nclasses=2, dim=2, and nchan=2 work with cellpose.
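The compatibility point in item 2 can be expressed as a small check. The reply states that omnipose models trained with nclasses=2, dim=2, and nchan=2 work with cellpose; it does not say that every other combination fails, so a non-empty result below only means the settings fall outside that stated case. `cellpose_mismatches` is a hypothetical helper, not an omnipose function, and it does not cover the omni=False changes mentioned in item 2.

```python
# Settings stated in the reply as producing cellpose-compatible models.
STATED_COMPATIBLE = {"nclasses": 2, "dim": 2, "nchan": 2}

def cellpose_mismatches(settings):
    """Return each stated setting that is missing from or differs in `settings`."""
    return {key: settings.get(key)
            for key, value in STATED_COMPATIBLE.items()
            if settings.get(key) != value}

# Settings from the training command in this issue; nchan was not set
# explicitly there (--all_channels was used instead).
print(cellpose_mismatches({"nclasses": 3, "dim": 3}))
# {'nclasses': 3, 'dim': 3, 'nchan': None}
```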