kevinjohncutler / omnipose

Omnipose: a high-precision solution for morphology-independent cell segmentation
https://omnipose.readthedocs.io
Other
96 stars 29 forks source link

model.eval error when not using GPU #106

Open niwaka-ame opened 2 weeks ago

niwaka-ame commented 2 weeks ago

Hello, first of all, many thanks for this amazing package. I'm using the pre-trained bact_fluor_omni model with a GPU, and this runs fine. However, when I tried running it without a GPU to compare speed, I encountered an issue.

I'm running the test in a separate Conda environment with Python 3.10.11, with the latest Omnipose installed from this GitHub repo (instead of PyPI).

The relevant part of my code is (mostly following the example in this repo):

# Minimal reproduction: segment the first three images with the pre-trained
# 'bact_fluor_omni' model on the CPU (gpu=False).
# Assumes `io`, `models` and `files` are defined earlier in the script (not shown):
# `models` is presumably cellpose_omni.models (the traceback goes through
# cellpose_omni/models.py), `io` is some module providing imread (e.g. skimage.io,
# TODO confirm), and `files` is a list of image paths.
imgs = [io.imread(f) for f in files]
model_name = 'bact_fluor_omni'
# gpu=False selects the CPU path ("Using CPU." in the log output); this is the
# configuration that ends in AttributeError: 'CPnet' object has no attribute 'module'.
model = models.CellposeModel(gpu=False, model_type=model_name)
chans = [0,0] # this means segment based on first channel, no second channel
n = range(3) # indices of the images to segment; requires len(imgs) >= 3, otherwise the list comprehension below raises IndexError

# define parameters (passed to model.eval as keyword arguments)
params = {'channels':chans, # always define this with the model
          'rescale': None, # upscale or downscale your images, None = no rescaling
          'mask_threshold': -2, # erode or dilate masks with higher or lower values between -5 and 5
          'flow_threshold': 0, # default is .4, but only needed if there are spurious masks to clean up; slows down output
          'transparency': True, # transparency in flow output
          'omni': True, # we can turn off Omnipose mask reconstruction, not advised
          'cluster': True, # use DBSCAN clustering
          'resample': True, # whether or not to run dynamics on rescaled grid or original grid
          'verbose': False, # turn on if you want to see more output
          # NOTE(review): the comment previously on this line ("average the outputs from flipped
          # (augmented) images; slower, usually not needed") describes test-time augmentation, which
          # is what 'augment' below controls. 'tile' appears to toggle tiled evaluation of the image
          # (cf. tile_overlap / bsize in eval's signature) — confirm against the Omnipose docs.
          'tile': False,
          'niter': None, # default None lets Omnipose calculate # of Euler iterations (usually <20) but you can tune it for over/under segmentation
          'augment': False, # Can optionally rotate the image and average network outputs, usually not needed
          'affinity_seg': False, # new feature, stay tuned...
         }

# Passing a list makes eval() call itself once per image (see the self.eval(x[i], ...)
# frame in the traceback), so the error surfaces on the first image (progress 0/3).
masks, flows, styles = model.eval([imgs[i] for i in n],**params)

The traceback is:

2024-09-18 16:41:44,687 [INFO]     cellpose_omni/models.py                       line 432   >>bact_fluor_omni<< model set to be used
2024-09-18 16:41:44,703 [INFO]     cellpose_omni/core.py         assi...evice()  line  72   Using CPU.
2024-09-18 16:41:44,726 [INFO]                                   __init__....()  line 163   u-net config: ([2, 32, 64, 128, 256], 4, 2)
2024-09-18 16:41:44,755 [INFO]     cellpose_omni/utils.py        flush.......()  line  47   0%|          | 0/3 [00:00<?, ?it/s]
2024-09-18 16:41:44,830 [INFO]                                                   line  47   0%|          | 0/3 [00:00<?, ?it/s]
-------------------------------------------------------------------------
AttributeError                          Traceback (most recent call last)
File ~/code/napari-gmm/napari_gmm/test_omnipose.py:81
     79 
     80 tic = time.time()
---> 81 masks, flows, styles = model.eval([imgs[i] for i in n],**params)
     82 
     83 net_time = time.time() - tic

File ~/code/omnipose/cellpose_omni/models.py:1065, in CellposeModel.eval(self, x, batch_size, indices, channels, channel_axis, z_axis, normalize, invert, rescale, diameter, do_3D, anisotropy, net_avg, augment, tile, tile_overlap, bsize, num_workers, resample, interp, cluster, suppress, boundary_seg, affinity_seg, despur, flow_threshold, mask_threshold, diam_threshold, niter, cellprob_threshold, dist_threshold, flow_factor, compute_masks, min_size, max_size, stitch_threshold, progress, show_progress, omni, calc_trace, verbose, transparency, loop_run, model_loaded, hysteresis)
   1059 rsc = rescale[i] if isinstance(rescale, list) or isinstance(rescale, np.ndarray) else rescale
   1060 chn = channels if channels is None else channels[i] if (len(channels)==len(x) and 
   1061                                                         (isinstance(channels[i], list) 
   1062                                                          or isinstance(channels[i], np.ndarray)) and
   1063                                                         len(channels[i])==2) else channels
-> 1065 maski, stylei, flowi = self.eval(x[i], 
   1066                                  batch_size=batch_size, 
   1067                                  channels = chn,
   1068                                  channel_axis=channel_axis, 
   1069                                  z_axis=z_axis, 
   1070                                  normalize=normalize, 
   1071                                  invert=invert,
   1072                                  rescale=rsc,
   1073                                  diameter=dia, 
   1074                                  do_3D=do_3D, 
   1075                                  anisotropy=anisotropy, 
   1076                                  net_avg=net_avg, 
   1077                                  augment=augment, 
   1078                                  tile=tile, 
   1079                                  tile_overlap=tile_overlap,
   1080                                  bsize=bsize,
   1081                                  resample=resample, 
   1082                                  interp=interp,
   1083                                  cluster=cluster,
   1084                                  suppress=suppress,
   1085                                  boundary_seg=boundary_seg,
   1086                                  affinity_seg=affinity_seg,
   1087                                  despur=despur,
   1088                                  mask_threshold=mask_threshold, 
   1089                                  diam_threshold=diam_threshold,
   1090                                  flow_threshold=flow_threshold, 
   1091                                  niter=niter,
   1092                                  flow_factor=flow_factor,
   1093                                  compute_masks=compute_masks, 
   1094                                  min_size=min_size, 
   1095                                  max_size=max_size,
   1096                                  stitch_threshold=stitch_threshold, 
   1097                                  progress=progress,
   1098                                  show_progress=show_progress,
   1099                                  omni=omni,
   1100                                  calc_trace=calc_trace, 
   1101                                  verbose=verbose,
   1102                                  transparency=transparency,
   1103                                  loop_run=(i>0),
   1104                                  model_loaded=model_loaded)
   1105 masks.append(maski)
   1106 flows.append(flowi)

File ~/code/omnipose/cellpose_omni/models.py:1157, in CellposeModel.eval(self, x, batch_size, indices, channels, channel_axis, z_axis, normalize, invert, rescale, diameter, do_3D, anisotropy, net_avg, augment, tile, tile_overlap, bsize, num_workers, resample, interp, cluster, suppress, boundary_seg, affinity_seg, despur, flow_threshold, mask_threshold, diam_threshold, niter, cellprob_threshold, dist_threshold, flow_factor, compute_masks, min_size, max_size, stitch_threshold, progress, show_progress, omni, calc_trace, verbose, transparency, loop_run, model_loaded, hysteresis)
   1154 rescale = self.diam_mean / diameter if (rescale is None and (diameter is not None and diameter>0)) else rescale
   1155 rescale = 1.0 if rescale is None else rescale
-> 1157 masks, styles, dP, cellprob, p, bd, tr, affinity, bounds  = self._run_cp(x, 
   1158                                                                           compute_masks=compute_masks,
   1159                                                                           normalize=normalize,
   1160                                                                           invert=invert,
   1161                                                                           rescale=rescale, 
   1162                                                                           net_avg=net_avg, 
   1163                                                                           resample=resample,
   1164                                                                           augment=augment, 
   1165                                                                           tile=tile, 
   1166                                                                           tile_overlap=tile_overlap,
   1167                                                                           bsize=bsize,
   1168                                                                           mask_threshold=mask_threshold, 
   1169                                                                           diam_threshold=diam_threshold,
   1170                                                                           flow_threshold=flow_threshold,
   1171                                                                           niter=niter,
   1172                                                                           flow_factor=flow_factor,
   1173                                                                           interp=interp,
   1174                                                                           cluster=cluster,
   1175                                                                           suppress=suppress,
   1176                                                                           boundary_seg=boundary_seg,  
   1177                                                                           affinity_seg=affinity_seg,
   1178                                                                           despur=despur,
   1179                                                                           min_size=min_size, 
   1180                                                                           max_size=max_size,
   1181                                                                           do_3D=do_3D, 
   1182                                                                           anisotropy=anisotropy,
   1183                                                                           stitch_threshold=stitch_threshold,
   1184                                                                           omni=omni,
   1185                                                                           calc_trace=calc_trace,
   1186                                                                           show_progress=show_progress,
   1187                                                                           verbose=verbose)
   1189 # the flow list stores: 
   1190 # (1) RGB representation of flows
   1191 # (2) flow components
   (...)
   1198 
   1199 # 5-8 were added in Omnipose, hence the unusual placement in the list. 
   1200 flows = [plot.dx_to_circ(dP,transparency=transparency) 
   1201          if self.nclasses>1 else np.zeros(cellprob.shape+(3+transparency,),np.uint8),
   1202          dP, cellprob, p, bd, tr, affinity, bounds]

File ~/code/omnipose/cellpose_omni/models.py:1287, in CellposeModel._run_cp(self, x, compute_masks, normalize, invert, rescale, net_avg, resample, augment, tile, tile_overlap, bsize, mask_threshold, diam_threshold, flow_threshold, niter, flow_factor, min_size, max_size, interp, cluster, suppress, boundary_seg, affinity_seg, despur, anisotropy, do_3D, stitch_threshold, omni, calc_trace, show_progress, verbose, pad)
   1285     else:
   1286         img = zoom(img,rescale,order=1)
-> 1287 yf, style = self._run_nets(img, net_avg=net_avg,
   1288                            augment=augment, tile=tile,
   1289                            normalize=normalize, 
   1290                            tile_overlap=tile_overlap, 
   1291                            bsize=bsize)
   1292 # unpadding 
   1293 yf = yf[unpad+(Ellipsis,)]

File ~/code/omnipose/cellpose_omni/core.py:412, in UnetModel._run_nets(self, img, net_avg, augment, tile, normalize, tile_overlap, bsize, return_conv, progress)
    409 for j in range(len(self.pretrained_model)):
    411     if self.torch and self.gpu:
--> 412         net = self.net.module
    413     else:
    414         net = self.net

File ~/anaconda3/envs/omnipose/lib/python3.10/site-packages/torch/nn/modules/module.py:1729, in Module.__getattr__(self, name)
   1727     if name in modules:
   1728         return modules[name]
-> 1729 raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

AttributeError: 'CPnet' object has no attribute 'module'

I'm not sure whether this is a software bug or a problem on my end. Many thanks!