kevinjohncutler / omnipose

Omnipose: a high-precision solution for morphology-independent cell segmentation
https://omnipose.readthedocs.io

LinAlgError: Singular matrix when training #12

Closed: gatoniel closed this issue 2 years ago

gatoniel commented 2 years ago

Hi, I am trying to train my own Omnipose model directly in Python. However, I get a LinAlgError: Singular matrix error. Here is my code:

import os
from glob import glob
from tifffile import imread
import numpy as np

from stardist import fill_label_holes
from cellpose.models import CellposeModel

path = r"D:\training-data\original"
imgs_files = glob(os.path.join(path, "Image*.tiff"))
lbls_files = [f.replace("Image", "segm").replace("tiff", "tif") for f in imgs_files]

imgs = [imread(f)[np.newaxis, ...] for f in imgs_files]
lbls = [fill_label_holes(imread(f)) for f in lbls_files]
print("imgs: ", [i.shape for i in imgs])
print("lbls: ", [i.shape for i in lbls])

model = CellposeModel(
    gpu=True, omni=True, nclasses=4, nchan=1, diam_mean=0,
)

model.train(
    imgs, lbls, save_every=1, n_epochs=100, batch_size=2, save_path=os.path.join(path, "model"),
    channels=0, channel_axis=0,
)

My output is:

imgs: [(1, 214, 259), (1, 214, 259), (1, 214, 259), (1, 214, 259)]
lbls: [(214, 259), (214, 259), (214, 259), (214, 259)]

core.py (972): divide by zero encountered in true_divide

---------------------------------------------------------------------------
LinAlgError                               Traceback (most recent call last)
Input In [6], in <cell line: 1>()
----> 1 model.train(
      2     imgs, lbls, save_every=1, n_epochs=100, batch_size=2, save_path=os.path.join(path, "model"),
      3     channels=0, channel_axis=0,
      4 )

File ~\anaconda3\envs\omnipose\lib\site-packages\cellpose\models.py:1022, in CellposeModel.train(self, train_data, train_labels, train_files, test_data, test_labels, test_files, channels, channel_axis, normalize, save_path, save_every, save_each, learning_rate, n_epochs, momentum, SGD, weight_decay, batch_size, nimg_per_epoch, rescale, min_train_masks, netstr, tyx)
   1020 if channels is None:
   1021     models_logger.warning('channels is set to None, input must therefore have nchan channels (default is 2)')
-> 1022 model_path = self._train_net(train_data, train_labels, 
   1023                              test_data=test_data, test_labels=test_labels,
   1024                              save_path=save_path, save_every=save_every, save_each=save_each,
   1025                              learning_rate=learning_rate, n_epochs=n_epochs, 
   1026                              momentum=momentum, weight_decay=weight_decay, 
   1027                              SGD=SGD, batch_size=batch_size, nimg_per_epoch=nimg_per_epoch, 
   1028                              rescale=rescale, netstr=netstr,tyx=tyx)
   1029 self.pretrained_model = model_path
   1030 return model_path

File ~\anaconda3\envs\omnipose\lib\site-packages\cellpose\core.py:974, in UnetModel._train_net(self, train_data, train_labels, test_data, test_labels, save_path, save_every, save_each, learning_rate, n_epochs, momentum, weight_decay, SGD, batch_size, nimg_per_epoch, rescale, netstr, do_autocast, tyx)
    972 rsc = diam_train[inds] / self.diam_mean if rescale else np.ones(len(inds), np.float32)
    973 # now passing in the full train array, need the labels for distance field
--> 974 imgi, lbl, scale = transforms.random_rotate_and_resize(
    975                         [train_data[i] for i in inds], Y=[train_labels[i] for i in inds],
    976                         rescale=rsc, scale_range=scale_range, unet=self.unet, tyx=tyx,
    977                         inds=inds, omni=self.omni, dim=self.dim, nchan=self.nchan)
    978 if self.unet and lbl.shape[1]>1 and rescale:
    979     lbl[:,1] /= diam_batch[:,np.newaxis,np.newaxis]**2

File ~\anaconda3\envs\omnipose\lib\site-packages\cellpose\transforms.py:839, in random_rotate_and_resize(X, Y, scale_range, gamma_range, tyx, do_flip, rescale, unet, inds, omni, dim, nchan, kernel_size)
    837     if tyx is None:
    838         tyx = (L,)*dim if dim==2 else (8*n,)+(8*n,)*(dim-1) #must be divisible by 2**3 = 8
--> 839     return omnipose.core.random_rotate_and_resize(X, Y=Y, scale_range=scale_range, gamma_range=gamma_range,
    840                                                   tyx=tyx, do_flip=do_flip, rescale=rescale, inds=inds, nchan=nchan)
    841 else:
    842     # backwards compatibility; completely 'stock', no gamma augmentation or any other extra frills. 
    843     # [Y[i][1:] for i in inds] is necessary because the original transform function does not use masks (entry 0). 
    844     # This used to be done in the original function call. 
    845     if tyx is None:

File ~\anaconda3\envs\omnipose\lib\site-packages\omnipose\core.py:1352, in random_rotate_and_resize(X, Y, scale_range, gamma_range, tyx, do_flip, rescale, inds, nchan)
   1347     y = None if Y is None else Y[n]
   1348     # use recursive function here to pass back single image that was cropped appropriately 
   1349     # # print(y.shape)
   1350     # skimage.io.imsave('/home/kcutler/DataDrive/debug/img_orig.png',img[0])
   1351     # skimage.io.imsave('/home/kcutler/DataDrive/debug/label_orig.tiff',y[n]) #so at this point the bad label is just fine 
-> 1352     imgi[n], lbl[n], scale[n] = random_crop_warp(img, y, nt, tyx, nchan, scale[n], 
   1353                                                  rescale is None if rescale is None else rescale[n], 
   1354                                                  scale_range, gamma_range, do_flip, 
   1355                                                  inds is None if inds is None else inds[n], dist_bg)
   1357 return imgi, lbl, np.mean(scale)

File ~\anaconda3\envs\omnipose\lib\site-packages\omnipose\core.py:1455, in random_crop_warp(img, Y, nt, tyx, nchan, scale, rescale, scale_range, gamma_range, do_flip, ind, dist_bg, depth)
   1453 c_in = 0.5 * np.array(s) + dxy
   1454 c_out = 0.5 * np.array(tyx)
-> 1455 offset = c_in - np.dot(np.linalg.inv(M), c_out)
   1457 # M = np.vstack((M,offset))
   1458 mode = 'reflect'

File <__array_function__ internals>:180, in inv(*args, **kwargs)

File ~\anaconda3\envs\omnipose\lib\site-packages\numpy\linalg\linalg.py:545, in inv(a)
    543 signature = 'D->D' if isComplexType(t) else 'd->d'
    544 extobj = get_linalg_error_extobj(_raise_linalgerror_singular)
--> 545 ainv = _umath_linalg.inv(a, signature=signature, extobj=extobj)
    546 return wrap(ainv.astype(result_t, copy=False))

File ~\anaconda3\envs\omnipose\lib\site-packages\numpy\linalg\linalg.py:88, in _raise_linalgerror_singular(err, flag)
     87 def _raise_linalgerror_singular(err, flag):
---> 88     raise LinAlgError("Singular matrix")

LinAlgError: Singular matrix

I wonder why there is a divide-by-zero warning and why the Singular matrix error occurs. Am I initializing the model incorrectly? Best, Niklas

kevinjohncutler commented 2 years ago

@gatoniel, I just tried training on my own data using the same code as above (with minor tweaks) and I cannot reproduce the issue. My images are also mono-channel, and training works whether or not I use np.newaxis to give them an explicit channel axis. In other words, training does work with the shapes in your output (thanks for including those in your debugging!): (1, Ly, Lx) for images and (Ly, Lx) for masks. So the only things I can think to ask are: do you have the latest GitHub version of Omnipose? And which versions of Python and NumPy are you using?
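To answer the version question quickly, a small snippet like the one below (a sketch, not part of Omnipose) prints the Python version and the installed package versions using the standard-library importlib.metadata module. The distribution names "omnipose" and "cellpose-omni" are assumptions about how the packages were installed; a package that is not found is reported as None.

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def report_versions(packages=("numpy", "omnipose", "cellpose-omni")):
    """Return the Python version and the installed versions of `packages`.

    The names are pip distribution names (assumed here); missing packages
    map to None instead of raising PackageNotFoundError.
    """
    info = {"python": sys.version.split()[0]}
    for pkg in packages:
        try:
            info[pkg] = version(pkg)
        except PackageNotFoundError:
            info[pkg] = None
    return info


print(report_versions())
```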

datqduong commented 1 year ago

I ran into this problem as well when I ran this command on my own data:

python -m omnipose --train --verbose --use_gpu --dir "~/data" --pretrained_model bact_fluor_omni --learning_rate 0.1 --n_epochs 100 --mask_filter _masks --img_filter _img --min_train_masks 5

In the "__main__.py" file here (https://github.com/kevinjohncutler/cellpose-omni/blob/4507c365aee6417bc6546c93698a181693f69612/cellpose/__main__.py#L434):

            else:
                rescale = True
                args.diameter = szmean 
                logger.info('pretrained model %s is being used'%cpmodel_path)
                args.residual_on = 1
                args.style_on = 1
                args.concatenation = 0
            if rescale and args.train:
                logger.info('during training rescaling images to fixed diameter of %0.1f pixels'%args.diameter)

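I think the problem is rescale=True combined with diam_mean=0: rsc = diam_train[inds] / self.diam_mean then divides by zero, which makes the rsc array infinite and explains the divide-by-zero warning from core.py line 972. Whenever I train from any of the bacterial models, rescale is set to True even though szmean=0. When I change rescale to False, training works normally.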
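The chain from diam_mean=0 to the singular matrix can be traced with plain NumPy. This is a simplified sketch, not Omnipose's code: it mirrors the rsc expression on core.py line 972 and the np.linalg.inv call on line 1455 from the traceback, but it assumes (unverified) that the augmentation multiplies a base scale by 1 / rsc and that the affine matrix M is a scaled rotation. The function name zero_diameter_chain and the diameters used are hypothetical.

```python
import numpy as np


def zero_diameter_chain(diam_train, diam_mean, theta=0.3):
    """Trace how diam_mean == 0 could turn into a singular affine matrix.

    Returns (rsc, scale, error_message); error_message is None when M is invertible.
    """
    diam_train = np.asarray(diam_train, dtype=np.float64)
    with np.errstate(divide="ignore", invalid="ignore"):
        # Mirrors core.py line 972: rsc = diam_train[inds] / self.diam_mean
        rsc = diam_train / diam_mean
    # Assumption: the base scale (1.0 here, no random jitter) is multiplied
    # by 1 / rsc, so an infinite rsc collapses the scale to 0.
    scale = 1.0 / rsc
    # Simplified affine matrix: a rotation by theta, scaled by the first factor.
    R = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    M = scale[0] * R
    try:
        # Mirrors core.py line 1455, which inverts M to compute the crop offset.
        np.linalg.inv(M)
    except np.linalg.LinAlgError as err:
        return rsc, scale, str(err)
    return rsc, scale, None


print(zero_diameter_chain([12.0, 15.0], diam_mean=0.0))
print(zero_diameter_chain([12.0, 15.0], diam_mean=30.0))
```

With diam_mean=0.0 the rescale factors are infinite, the scale collapses to 0, M is the zero matrix, and inversion fails with "Singular matrix"; with a positive diam_mean the same steps succeed.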

Please have a look at it, @kevinjohncutler. Thanks!
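Following the diagnosis in this comment, one hedged workaround is to avoid rescaling whenever the mean diameter is not positive. From the Python API, the counterpart of the commenter's change would presumably be passing rescale=False to model.train (the traceback shows that CellposeModel.train accepts a rescale argument), though that has not been verified here. The helper below is a hypothetical guard modeled on the expression on core.py line 972; safe_rescale_factors is an illustrative name, not an Omnipose function.

```python
import numpy as np


def safe_rescale_factors(diam_train, diam_mean, rescale=True):
    """Per-image rescale factors that never divide by a zero diam_mean.

    Follows rsc = diam_train / diam_mean from core.py line 972, but falls back
    to factors of 1 (no rescaling) when rescaling is off or diam_mean <= 0.
    """
    diam_train = np.asarray(diam_train, dtype=np.float32)
    if rescale and diam_mean > 0:
        return diam_train / diam_mean
    return np.ones(len(diam_train), np.float32)


print(safe_rescale_factors([10.0, 20.0], diam_mean=20.0))
print(safe_rescale_factors([10.0, 20.0], diam_mean=0))
```

Falling back to factors of 1 matches what the original expression already does when rescale is False, so a zero diam_mean simply behaves like training without rescaling instead of crashing.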