mazurowski-lab / finetune-SAM

This is an official repo for fine-tuning SAM to customized medical images.
https://arxiv.org/abs/2404.09957
Apache License 2.0
136 stars, 22 forks

About the args "multimask_output=True" in sam.mask_decoder #14

Closed: happyday521 closed this issue 2 months ago

happyday521 commented 4 months ago

When I set "multimask_output=False", the code raises cuDNN errors (e.g., RuntimeError: Unable to find a valid cuDNN algorithm to run convolution). When it is True, the code runs normally.

Can I set this arg to False in your code? Do you know the possible reason for this problem? Thanks!

surpasss commented 3 months ago

Have you solved it? Thanks!

happyday521 commented 3 months ago

> Have you solved it? Thanks!

Not yet.

Guhanxue commented 2 months ago

Hi, thanks for reaching out. It seems that this argument isn't available in our current codebase; you can verify this by checking the args.py file.

Regarding the original SAM's multimask_output: it generates three different predictions for the same 'object' or 'class' to handle ambiguity when box/point prompts are provided.
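For reference, the channel selection in the original SAM mask decoder can be modelled in a few lines. This is a simplified, pure-Python sketch, assuming the original SAM defaults (three multimask outputs plus one extra single-mask token); the helper name `select_mask_tokens` is hypothetical and only mirrors the `slice(1, None)` versus `slice(0, 1)` choice made inside SAM's `MaskDecoder.forward`.

```python
from typing import List


def select_mask_tokens(num_multimask_outputs: int, multimask_output: bool) -> List[int]:
    """Return the indices of the mask-token outputs that are kept.

    Simplified model (assumption) of the original SAM decoder: it has
    num_multimask_outputs + 1 mask tokens. Token 0 is the single-mask
    output; tokens 1..N are the ambiguity-aware multimask outputs.
    """
    num_mask_tokens = num_multimask_outputs + 1
    all_tokens = list(range(num_mask_tokens))
    if multimask_output:
        # Equivalent to slice(1, None): several candidate masks.
        return all_tokens[1:]
    # Equivalent to slice(0, 1): a single mask.
    return all_tokens[0:1]


print(select_mask_tokens(3, True))   # [1, 2, 3]
print(select_mask_tokens(3, False))  # [0]
```

So with the original defaults, True yields three masks and False yields one.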

When fine-tuning the model in automatic mode, our code handles the output channels differently. Specifically, we adjust the number of output channels to match the number of classes (with one channel reserved for the background class):

sam = sam_model_registry[args.arch](args, checkpoint=os.path.join(args.sam_ckpt), num_classes=args.num_cls)
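With num_cls output channels (channel 0 = background), a per-pixel argmax turns the logits into a label map. The sketch below uses plain nested lists of shape [C][H][W] for a single image instead of tensors, so it runs without PyTorch; `logits_to_labels` is a hypothetical helper, not a function from this repo.

```python
from typing import List


def logits_to_labels(logits: List[List[List[float]]]) -> List[List[int]]:
    """Per-pixel argmax over the channel axis of a [C][H][W] logit map.

    Channel 0 is assumed to be background, channels 1..C-1 foreground classes.
    """
    num_channels = len(logits)
    height, width = len(logits[0]), len(logits[0][0])
    labels = [[0] * width for _ in range(height)]
    for y in range(height):
        for x in range(width):
            scores = [logits[c][y][x] for c in range(num_channels)]
            # The highest-scoring channel index is the predicted class.
            labels[y][x] = max(range(num_channels), key=scores.__getitem__)
    return labels


# Binary task: num_cls = 2 -> two channels (background, foreground), 2x2 image.
logits = [
    [[2.0, 0.1], [0.3, 1.5]],  # channel 0: background
    [[0.5, 1.2], [0.9, 0.2]],  # channel 1: foreground
]
print(logits_to_labels(logits))  # [[0, 1], [1, 0]]
```

With only one channel, the argmax is always 0, so no pixel could ever be labelled foreground.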

For instance, if you're working on a binary segmentation task, you would set args.num_cls=2, resulting in an output with 2 channels. If you set multimask_output=False, only the first output channel (typically the background channel) would be returned as pred. This single-channel output would not match the shape expected by the following loss:

loss = criterion1(pred, msks.float()) + criterion2(pred, torch.squeeze(msks.long(), 1))

This mismatch could lead to a cuDNN error.
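Why a single channel breaks the second loss term can be seen without a GPU. A cross-entropy criterion such as torch.nn.CrossEntropyLoss expects pred of shape (B, C, H, W) and integer targets in [0, C); with C = 1, any foreground pixel labelled 1 is out of range. The sketch below is a pure-Python stand-in (the name `pixel_cross_entropy` is hypothetical) that performs the same range check and computes the mean per-pixel cross-entropy for one image; it does not reproduce the exact PyTorch or CUDA error message.

```python
import math
from typing import List


def pixel_cross_entropy(logits: List[List[List[float]]], target: List[List[int]]) -> float:
    """Mean cross-entropy over the pixels of a single [C][H][W] logit map."""
    num_channels = len(logits)
    total, count = 0.0, 0
    for y, row in enumerate(target):
        for x, label in enumerate(row):
            if not 0 <= label < num_channels:
                # The out-of-range target condition that a real CE loss rejects.
                raise ValueError(f"target {label} out of range for {num_channels} channel(s)")
            scores = [logits[c][y][x] for c in range(num_channels)]
            # Log-sum-exp with max subtraction for numerical stability.
            m = max(scores)
            log_z = m + math.log(sum(math.exp(s - m) for s in scores))
            total += log_z - scores[label]
            count += 1
    return total / count


target = [[0, 1], [1, 0]]  # binary mask: 0 = background, 1 = foreground
two_channel = [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]
print(pixel_cross_entropy(two_channel, target))  # about 0.6931, i.e. ln(2)

one_channel = two_channel[:1]  # only the first channel kept
try:
    pixel_cross_entropy(one_channel, target)
except ValueError as err:
    print(err)  # target 1 out of range for 1 channel(s)
```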

It would be really helpful if you could provide more details on why you prefer a single-channel output. I'd be happy to assist in adjusting the code (perhaps the loss function) to accommodate this need.
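If a single-channel output is really needed, one common option for binary segmentation (an assumption here, not a change the maintainers have made) is to treat that channel as foreground logits and train it with sigmoid binary cross-entropy plus a soft Dice term instead of the two-class cross-entropy above. In PyTorch this would typically be torch.nn.BCEWithLogitsLoss combined with a Dice loss; the pure-Python sketch below (the name `bce_dice_loss` is hypothetical) only shows the arithmetic on flattened pixels.

```python
import math
from typing import List


def stable_sigmoid(z: float) -> float:
    # Split on sign so math.exp never receives a large positive argument.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bce_dice_loss(logits: List[float], mask: List[int], eps: float = 1e-6) -> float:
    """Sigmoid BCE + soft Dice loss for one flattened single-channel prediction.

    logits: raw foreground scores, one per pixel.
    mask:   ground-truth labels, 1 = foreground, 0 = background.
    """
    bce = 0.0
    probs = []
    for z, y in zip(logits, mask):
        # Numerically stable BCE-with-logits: max(z, 0) - z*y + log(1 + exp(-|z|)).
        bce += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
        probs.append(stable_sigmoid(z))
    bce /= len(logits)
    # Soft Dice coefficient between foreground probabilities and the mask.
    intersection = sum(p * y for p, y in zip(probs, mask))
    dice = (2.0 * intersection + eps) / (sum(probs) + sum(mask) + eps)
    return bce + (1.0 - dice)


mask = [1, 0, 1, 0]
print(bce_dice_loss([4.0, -4.0, 3.0, -3.0], mask))  # well-aligned prediction: small loss
print(bce_dice_loss([-4.0, 4.0, -3.0, 3.0], mask))  # inverted prediction: large loss
```

With this setup, a ground-truth mask containing labels 0 and 1 is compared against one channel directly, so no class-index target is needed.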

happyday521 commented 2 months ago

Got it. Thanks very much!