WZH0120 / SAM2-UNet

SAM2-UNet: Segment Anything 2 Makes Strong Encoder for Natural and Medical Image Segmentation
Apache License 2.0

Is incremental training possible with SAM2-UNet? #25

Open Tony-Duan2020 opened 1 day ago

Tony-Duan2020 commented 1 day ago

Hi,

Thank you for the excellent code. I would like to know whether incremental training is possible with SAM2-UNet.

I tried passing the checkpoint 'SAM2-UNet-xx.pth' as 'hiera_path' in train.py, but I encountered the following error:

  File "D:\MyProjects\SAM2-UNet\sam2\build_sam.py", line 81, in _load_checkpoint
    sd = torch.load(ckpt_path, map_location="cpu")["model"]
KeyError: 'model'

I hope to perform incremental training on my model. For example, I collected data A and trained a model on it. Later, I want to build on that model and continue training it with data B. What would you suggest for this process?
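The traceback comes from SAM2's `_load_checkpoint`, which evaluates `torch.load(ckpt_path, map_location="cpu")["model"]`, so `hiera_path` only accepts files that wrap their weights under a `"model"` key, as the official SAM2 checkpoints do. The SAM2-UNet checkpoint evidently has no such key, which is consistent with it being saved as a plain `state_dict`. A minimal sketch for telling the two layouts apart, using the hypothetical helper names `checkpoint_kind` and `inspect_checkpoint`:

```python
from collections.abc import Mapping


def checkpoint_kind(ckpt: Mapping) -> str:
    """Classify a loaded checkpoint object (hypothetical helper).

    "sam2"       -> {"model": state_dict}, the layout SAM2's _load_checkpoint indexes
    "state_dict" -> a flat name -> tensor mapping, e.g. from torch.save(model.state_dict(), path)
    """
    if isinstance(ckpt.get("model"), Mapping):
        return "sam2"
    return "state_dict"


def inspect_checkpoint(path: str, n_keys: int = 5) -> str:
    # Imported lazily so checkpoint_kind can be used without torch installed.
    import torch

    ckpt = torch.load(path, map_location="cpu")
    kind = checkpoint_kind(ckpt)
    # Pick the actual parameter mapping for either layout and show a few names.
    weights = ckpt["model"] if kind == "sam2" else ckpt
    print(f"{path}: {kind}, first keys: {list(weights)[:n_keys]}")
    return kind
```

This only diagnoses the format. A SAM2-UNet checkpoint presumably stores parameter names for the whole SAM2-UNet model rather than for the bare SAM2 model that `hiera_path` feeds, so it would be restored with `load_state_dict` on the SAM2-UNet model instead of through `hiera_path`.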

xiongxyowo commented 1 day ago

Hi, when testing or continuing training, the parameter --hiera-path should be set to None (see test.py), because the encoder parameters are overwritten during training and saved in the new checkpoint. Our original code does not support resuming training. Please refer to issue #20 for guidance on how to perform incremental training. Additionally, since you are continuing training on a different dataset, you may only need to save the model's 'state_dict'.
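The workflow described in the reply (build the model without Hiera weights, load the data-A checkpoint, train on data B, save only the `state_dict`) might look roughly like the sketch below. It assumes the repository's `SAM2UNet` class can be imported from `SAM2UNet.py` and that its constructor accepts `checkpoint_path=None`; `continue_training` and `finetune_from_previous` are hypothetical names, the learning rate is illustrative, and `loss_fn` must be adapted to whatever outputs the model returns (for example, the loss used in train.py). It deliberately does not save optimizer state or epoch counters, so it is fine-tuning from saved weights, not resumption.

```python
from typing import Any, Callable, Iterable, Tuple


def continue_training(
    model: Any,
    optimizer: Any,
    batches: Iterable[Tuple[Any, Any]],  # must be re-iterable, e.g. a DataLoader
    loss_fn: Callable[[Any, Any], Any],
    epochs: int,
) -> int:
    """Train an already-initialised model for more epochs; returns optimizer steps taken."""
    model.train()
    steps = 0
    for _ in range(epochs):
        for images, masks in batches:
            optimizer.zero_grad()
            loss = loss_fn(model(images), masks)
            loss.backward()
            optimizer.step()
            steps += 1
    return steps


def finetune_from_previous(
    prev_ckpt: str,
    new_ckpt: str,
    train_loader: Iterable[Tuple[Any, Any]],
    loss_fn: Callable[[Any, Any], Any],
    epochs: int = 10,
    lr: float = 1e-4,  # illustrative value, not taken from the repository
) -> None:
    # Imported lazily so continue_training can be used without torch installed.
    import torch
    from SAM2UNet import SAM2UNet  # assumption: class defined in the repo's SAM2UNet.py

    # Assumption: checkpoint_path=None builds the model without loading Hiera
    # weights, mirroring --hiera-path None in test.py. Runs on CPU; moving the
    # model and batches to a GPU is omitted for brevity.
    model = SAM2UNet(checkpoint_path=None)
    # The data-A checkpoint is assumed to be a plain state_dict.
    state = torch.load(prev_ckpt, map_location="cpu")
    model.load_state_dict(state, strict=True)

    # Only optimise parameters that are trainable; frozen ones are left out.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=lr)

    continue_training(model, optimizer, train_loader, loss_fn, epochs)
    # Weights only: enough for later fine-tuning or testing, not for resuming.
    torch.save(model.state_dict(), new_ckpt)
```

Keeping the loop separate from model construction means the same `continue_training` can be reused for data C, D, and so on, each time starting from the previous stage's saved `state_dict`.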