jiyuuchc / lacss

A deep learning model for single cell segmentation from microscopy images.
https://jiyuuchc.github.io/lacss/
MIT License

Cannot save the trained model in "train_with_point_label.ipynb" #10

Closed lphilomena closed 2 months ago

lphilomena commented 4 months ago

Hello! Many thanks for the nice work. I tried the "train_with_point_label.ipynb" notebook and it works well. However, I got an error when trying to save the trained model.

The revised code is as follows:

trainer.do_training(
    TFDatasetAdapter(ds),
    n_steps = n_steps,
    validation_interval = validation_interval,
    init_vars = dict(params=params),
    checkpoint_manager=cp_mngr,
)

I got the following error:

TypeError                                 Traceback (most recent call last)
Cell In[10], line 9
      4 logpath = '/home/lacss/notebooks/bright_test/'
      5 cp_mngr = orbax.checkpoint.CheckpointManager(
      6     logpath,
      7 )
----> 9 trainer.do_training(
     10     TFDatasetAdapter(ds),
     11     n_steps = n_steps,
     12     validation_interval = validation_interval,
     13     init_vars = dict(params=params),
     14     checkpoint_manager=cp_mngr,
     15 )

File ~/anaconda3/envs/lacss/lib/python3.11/site-packages/lacss/train/lacss_trainer.py:313, in LacssTrainer.do_training(self, dataset, val_dataset, n_steps, validation_interval, checkpoint_manager, warmup_steps, sigma, pi, init_vars)
    305 next_cp_step = (
    306     (cur_step + validation_interval)
    307     // validation_interval
    308     * validation_interval
    309 )
    311 print(f"Current step {cur_step} going to {next_cp_step}")
--> 313 self._train_to_next_interval(
    314     next_cp_step - cur_step,
    315     trainer,
...
    154     lacss.metrics.LoiAP([5, 2, 1]),
    155     lacss.metrics.BoxAP([0.5, 0.75]),
    156 ]

TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'

I tried to debug trainer.do_training, but the debugger didn't step into the function "_train_to_next_interval" at lacss_trainer.py:156. Is there a setting I should change in order to debug this, or could you please give an example of saving the trained model in "train_with_point_label.ipynb"? Thanks again!

jiyuuchc commented 4 months ago

Well, it looks like we broke the checkpoint interface. Until we fix it properly, could you try the following workaround:

from pathlib import Path
import orbax.checkpoint as ocp

logpath = '/home/lacss/notebooks/bright_test/'
cp_mngr = ocp.CheckpointManager(
    Path(logpath).absolute(),  # orbax doesn't like relative paths
)
# save a placeholder checkpoint at step 0 before training
cp_mngr.save(0, args=ocp.args.StandardSave(1))

# do training
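
For context, the arithmetic shown in the traceback can be reproduced in isolation. The sketch below assumes that cur_step is None when no checkpoint has been saved yet; that is an inference from the error message, not something confirmed in this thread. The function name next_checkpoint_step is hypothetical and is not part of lacss.

```python
def next_checkpoint_step(cur_step, validation_interval):
    """Round cur_step up to the next multiple of validation_interval.

    Mirrors the expression shown in the traceback; the name is hypothetical.
    """
    return (cur_step + validation_interval) // validation_interval * validation_interval


# With a known step, the arithmetic works as expected.
print(next_checkpoint_step(0, 1000))     # 1000
print(next_checkpoint_step(1500, 1000))  # 2000

# Assumption: with no checkpoint saved yet, the current step is None,
# which raises the same TypeError as in the report above.
try:
    next_checkpoint_step(None, 1000)
except TypeError as err:
    print(err)
```

If that assumption holds, saving a checkpoint at step 0 first gives the trainer an integer step to start from, which would explain why the workaround avoids the error.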

In addition, please note that if you only need to save the model for inference, you don't need the full model checkpoint, which contains much more data than the model itself (optimizer state, etc.). Instead, call:

trainer.save("model_save.bin")

This simply saves a pickled object containing the model and its parameters. It is also the save format you need if you want to use any of the functions in the lacss.deploy module.
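
As a rough illustration of what "a pickled object of the model plus its parameters" means, here is a minimal sketch using only the Python standard library. The names save_model and load_model, and the dictionaries standing in for the model and its parameters, are all hypothetical. This is not how lacss implements trainer.save, and the real file layout may differ.

```python
import pickle
import tempfile
from pathlib import Path


def save_model(path, model_config, params):
    # Bundle the model description and its parameters only;
    # no optimizer state or other training bookkeeping is stored.
    with open(path, "wb") as f:
        pickle.dump((model_config, params), f)


def load_model(path):
    # Only unpickle files you trust: pickle can execute code on load.
    with open(path, "rb") as f:
        return pickle.load(f)


# Hypothetical stand-ins for a model definition and its trained weights.
toy_config = {"n_layers": 3}
toy_params = {"w": [0.1, 0.2], "b": [0.0]}

with tempfile.TemporaryDirectory() as tmp:
    save_path = Path(tmp) / "model_save.bin"
    save_model(save_path, toy_config, toy_params)
    loaded_config, loaded_params = load_model(save_path)

print(loaded_config, loaded_params)
```

A file like this is enough for inference because the parameters and the model description are all that is needed to run the model, whereas a full checkpoint also carries the training state needed to resume training.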

lphilomena commented 3 months ago

trainer.save("model_save.bin") works! Thanks for your reply.

jiyuuchc commented 2 months ago

Resolved