facebookresearch / sam2

The repository provides code for running inference with the Meta Segment Anything Model 2 (SAM 2), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.

Not able to use model with Fine Tuned Model weights #246

Open owenip opened 3 months ago

owenip commented 3 months ago

Hi, I am having a problem loading fine-tuned model weights. Perhaps someone with SAM2 fine-tuning experience can shed some light on this.

The fine-tuned model works fine right after the training loop, but it is not able to segment anything when the same weights are loaded back from a torch file. I have rerun this and double-checked the loaded weights.

Below is an example of the model output. Only a single bounding-box prompt is used. The fine-tuned model straight from the training loop is able to segment the target, but the model loaded from the saved state_dict() is not.

Model 1 is the fine-tuned model from the training loop. Model 2 is the model loaded from the fine-tuned state_dict().
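
To make the setup concrete, here is a minimal sketch of the save/load flow described above, assuming the repo's build_sam2 / SAM2ImagePredictor entry points are used; the config name, checkpoint file names, and the fine-tuning loop itself are placeholders, not the reporter's actual code:

```python
# Minimal sketch of the save/load flow described above.
# Config and checkpoint names are placeholders; the fine-tuning loop is omitted.
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Build the model from a config + base checkpoint
sam2_model = build_sam2("sam2_hiera_s.yaml", ckpt_path="sam2_hiera_small.pt", device="cuda")

# ... fine-tuning loop goes here (omitted) ...

# Save only the fine-tuned weights
torch.save(sam2_model.state_dict(), "fine_tuned_sam2.pt")

# Later, in a fresh process: rebuild from the same config and load the saved weights
sam2_model = build_sam2("sam2_hiera_s.yaml", ckpt_path=None, device="cuda")
sam2_model.load_state_dict(torch.load("fine_tuned_sam2.pt", map_location="cuda"))
sam2_model.eval()

predictor = SAM2ImagePredictor(sam2_model)
```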

lzl2040 commented 2 months ago

Have you solved this problem? I have run into the same issue.

arawxx commented 2 months ago

OH GOD. I've been dealing with this problem for a whole week now and it's driving me insane. I have not been able to solve it yet. I hypothesize it has something to do with the memory block...

owenip commented 2 months ago

It's driving me crazy as well. I have no idea where or what went wrong.

heyoeyo commented 2 months ago

I might be misunderstanding the setup, but if they're supposed to be the same, then it seems something is going wrong with the hi-res embeddings. As a sanity check, it may be worth disabling the hi-res embeddings (by setting the use_high_res_features_in_sam config to False) to see if that avoids the broken mask output. It would also be interesting to turn off the +/- 32 clamping to see whether the output mask forms a reasonable pattern but with overly negative values, or whether it's going off to negative infinity. If it's outputting infinity, then it may be a data type/numerical issue, in which case switching to float32 could help if it's not already being used.

I hypothesize it has something to do with the memory block...

At least for image segmentation, you can disable the use of memory features on the image encoder by turning off the directly_add_no_mem_embed config setting.
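
For reference, a rough sketch of how both of those config settings could be switched off when rebuilding the model, via build_sam2's hydra_overrides_extra argument; the config name, file names, and exact override keys are assumptions (based on the model-level keys in the shipped sam2_hiera_*.yaml configs), so verify them against the config used for training:

```python
# Sketch only: disabling the hi-res embeddings and the no-memory embedding
# via Hydra overrides when rebuilding the model. Override keys and file
# names are assumptions; check them against your training config.
import torch
from sam2.build_sam import build_sam2

sam2_model = build_sam2(
    "sam2_hiera_s.yaml",
    ckpt_path=None,
    device="cuda",
    hydra_overrides_extra=[
        "++model.use_high_res_features_in_sam=false",  # skip hi-res embeddings in the mask decoder
        "++model.directly_add_no_mem_embed=false",     # skip adding the no-memory embedding to image features
    ],
)
# strict=False because disabling the hi-res features may drop some decoder submodules
sam2_model.load_state_dict(torch.load("fine_tuned_sam2.pt", map_location="cuda"), strict=False)
sam2_model = sam2_model.float().eval()  # float32, as suggested above
```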

bidulgi123 commented 1 week ago

If you use torch.cuda.amp.autocast during training and prediction, try changing it to torch.cuda.amp.autocast(enabled=False). It is most likely a mixed precision issue.
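
For example (a sketch only; predictor, image, and box stand in for whatever objects the fine-tuning setup actually uses):

```python
# Sketch: force full precision at inference to rule out an autocast-related
# numerical issue. `predictor`, `image`, and `box` are placeholder objects.
import torch

with torch.cuda.amp.autocast(enabled=False):
    predictor.set_image(image)
    masks, scores, low_res_logits = predictor.predict(
        box=box,
        multimask_output=False,
    )
```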