yogendra-yatnalkar / SAM-Promptless-Task-Specific-Finetuning

Promptless-TaskSpecific-Finetuning of MetaAI Segment-Anything Model
Apache License 2.0

Loading saved decoder #1

Open ansuehu opened 4 months ago

ansuehu commented 4 months ago

Hi!

First of all, I would like to thank you for the great work; it's amazing how well it works. Once I have saved the model, how can I load it back?

Thanks in advance!

yogendra-yatnalkar commented 4 months ago

Hi @ansuehu, glad you liked it. To save the model, you will have to save the encoder and the decoder separately.

The encoder has not been changed, so you can directly download it from torch-hub instead of saving it (optional).

# Loading the SAM model
from segment_anything import sam_model_registry

sam = sam_model_registry["vit_b"](checkpoint="/kaggle/working/sam_vit_b_01ec64.pth")

For the decoder, save it using the state_dict method shown here: https://pytorch.org/tutorials/beginner/saving_loading_models.html

When loading it back, first define the decoder using the original SAM model, and then load its weights from the saved state_dict:

sam_decoder = SAM_Decoder(sam_encoder = sam.image_encoder, sam_preprocess = sam.preprocess)
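
For reference, here is a minimal sketch of the full save/load round trip using plain PyTorch. It does not use the real SAM weights: `StandInDecoder`, `make_encoder`, and the file name `decoder_state.pth` are hypothetical placeholders standing in for `SAM_Decoder`, `sam.image_encoder`, and whatever path you choose. The real `SAM_Decoder` is constructed as shown above (with `sam_encoder` and `sam_preprocess`); only the `state_dict` pattern is meant to carry over.

```python
import os
import tempfile

import torch
import torch.nn as nn


class StandInDecoder(nn.Module):
    """Hypothetical stand-in for the repository's SAM_Decoder.

    Wraps a frozen "encoder" plus a small trainable head, purely to
    illustrate saving and restoring weights via state_dict.
    """

    def __init__(self, encoder: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.head = nn.Conv2d(8, 1, kernel_size=1)

    def forward(self, x):
        # The encoder is treated as frozen, so no gradients are tracked here.
        with torch.no_grad():
            features = self.encoder(x)
        return self.head(features)


def make_encoder() -> nn.Module:
    # Placeholder for sam.image_encoder (assumption, not the real encoder).
    return nn.Conv2d(3, 8, kernel_size=3, padding=1)


# After fine-tuning: save only the weights, not the whole pickled object.
trained = StandInDecoder(make_encoder())
ckpt_path = os.path.join(tempfile.mkdtemp(), "decoder_state.pth")
torch.save(trained.state_dict(), ckpt_path)

# Later: rebuild the same architecture first, then load the saved weights.
restored = StandInDecoder(make_encoder())
restored.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
restored.eval()  # switch to inference mode before predicting
```

Because the encoder is a submodule in this sketch, its weights are included in the saved state_dict and are overwritten on load. If the same holds for your decoder, the restored model should reproduce the fine-tuned model's outputs.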

Please let me know if you face any issues; I will be happy to help.