mazurowski-lab / finetune-SAM

This is the official repo for fine-tuning SAM on customized medical images.
https://arxiv.org/abs/2404.09957
Apache License 2.0

About the effect of args.if_update_encoder #5

Closed: happyday521 closed this issue 6 months ago

happyday521 commented 6 months ago

Hi, thanks for your great work! I have a question about the code.

In SingleGPU_train_finetune_noprompt.py, lines 80-87, it seems that setting args.if_update_encoder to True or False has no effect. Are only the mask decoder parameters ever updated? This confuses me a little.

Looking forward to your reply!
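One way to check the behavior described above is to inspect, after a single backward pass, which submodules are registered with the optimizer and which actually received gradients. The sketch below is a minimal, hypothetical illustration: gradient_report and the toy model (two nn.Linear layers named image_encoder and mask_decoder) are assumptions made for this example, not code from the repository. It shows that when the encoder call is wrapped in torch.no_grad(), the encoder's parameters get no gradient, so the optimizer leaves them unchanged even though they are in its parameter list.

```python
import torch
import torch.nn as nn


def gradient_report(model: nn.Module, optimizer: torch.optim.Optimizer) -> dict:
    """Map each top-level submodule name to (in_optimizer, received_grad)."""
    opt_ids = {id(p) for group in optimizer.param_groups for p in group["params"]}
    report = {}
    for name, param in model.named_parameters():
        top = name.split(".")[0]
        in_opt, has_grad = report.get(top, (False, False))
        report[top] = (in_opt or id(param) in opt_ids,
                       has_grad or param.grad is not None)
    return report


# Hypothetical toy model; the attribute names only mimic SAM's submodules
model = nn.Module()
model.image_encoder = nn.Linear(8, 4)
model.mask_decoder = nn.Linear(4, 2)
# The optimizer holds all parameters, encoder included
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(3, 8)
with torch.no_grad():  # encoder forward pass without autograd tracking
    emb = model.image_encoder(x)
loss = model.mask_decoder(emb).sum()
loss.backward()

print(gradient_report(model, optimizer))
# {'image_encoder': (True, False), 'mask_decoder': (True, True)}
```

Because the encoder never receives a gradient, optimizer.step() has nothing to apply to it, which matches the observation that only the mask decoder changes.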

Guhanxue commented 6 months ago

Hi, thanks for pointing this out, and sorry for the confusion. I mainly used DDP on multiple GPUs to train both the image encoder and the mask decoder; in my experiments, SingleGPU_train_finetune_noprompt.py was used mainly for decoder-only fine-tuning because of the limited memory of our A6000 GPUs.

To update the image encoder, remove the 'with torch.no_grad():' wrapper around the image-encoder call so that gradients can flow back into it, as in this loop:

    for i, data in enumerate(trainloader):
        imgs = data['image'].cuda()
        msks = torchvision.transforms.Resize((args.out_size, args.out_size))(data['mask'])
        msks = msks.cuda()
        # No torch.no_grad() wrapper here, so gradients flow into the image encoder
        # (the optimizer must also hold the encoder's parameters for them to be updated)
        img_emb = sam.image_encoder(imgs)
        # get default (prompt-free) embeddings
        sparse_emb, dense_emb = sam.prompt_encoder(
            points=None,
            boxes=None,
            masks=None,
        )
        pred, _ = sam.mask_decoder(
            image_embeddings=img_emb,
            image_pe=sam.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_emb,
            dense_prompt_embeddings=dense_emb,
            multimask_output=True,
        )
        # combined Dice + cross-entropy loss
        loss_dice = criterion1(pred, msks.float())
        loss_ce = criterion2(pred, torch.squeeze(msks.long(), 1))
        loss = loss_dice + loss_ce

        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

I have also updated this script in the repository. Thanks sooooooo much!
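If the flag is meant to switch encoder training on and off, one option is to tie both the gradient context and the optimizer's parameter list to it. The following is a minimal sketch under assumptions: toy nn.Linear layers stand in for sam.image_encoder and sam.mask_decoder, an MSE loss stands in for the Dice + cross-entropy loss, and build and train_step are hypothetical helpers rather than functions from the repository. torch.set_grad_enabled(update_encoder) behaves like torch.no_grad() when the flag is False and keeps grad mode on when it is True.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def build(update_encoder: bool, lr: float = 1e-2):
    """Create a toy encoder/decoder and an optimizer that matches the flag."""
    torch.manual_seed(0)
    encoder = nn.Linear(8, 4)  # hypothetical stand-in for sam.image_encoder
    decoder = nn.Linear(4, 2)  # hypothetical stand-in for sam.mask_decoder
    params = list(decoder.parameters())
    if update_encoder:
        # Encoder weights are handed to the optimizer only when they should train
        params += list(encoder.parameters())
    return encoder, decoder, torch.optim.Adam(params, lr=lr)


def train_step(encoder, decoder, optimizer, x, y, update_encoder: bool) -> float:
    """One optimization step; autograd tracks the encoder only if update_encoder."""
    # set_grad_enabled(False) acts like torch.no_grad(); True keeps grad mode on
    with torch.set_grad_enabled(update_encoder):
        emb = encoder(x)
    loss = F.mse_loss(decoder(emb), y)  # placeholder loss for the sketch
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    return loss.item()


for flag in (False, True):
    enc, dec, opt = build(flag)
    before = enc.weight.detach().clone()
    train_step(enc, dec, opt, torch.randn(16, 8), torch.randn(16, 2), update_encoder=flag)
    print(f"update_encoder={flag}: encoder changed = {not torch.equal(before, enc.weight)}")
# update_encoder=False: encoder changed = False
# update_encoder=True: encoder changed = True
```

Driving both the gradient context and the optimizer's parameter list from the same flag keeps them from drifting apart; if either one ignores the flag, the encoder stays frozen regardless of its value.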

happyday521 commented 6 months ago

Thanks for your reply!