Qsingle / LearnablePromptSAM

An attempt to use SAM-ViT as the backbone to create a learnable prompt for semantic segmentation
Apache License 2.0

how is num_class decided? #8

Open krishnaadithya opened 1 year ago

krishnaadithya commented 1 year ago

I saw a few threads in the issues, but I am still a bit confused about the classes. I am training the model to do vein segmentation; the mask contains only classes 0 and 1, and its resolution is (h, w) with no channel dimension (this also throws an error in the model).

Can you please give a training example, with a dataset?

Qsingle commented 1 year ago

Thank you for the report. I cannot locate the error from the information you gave; can you give me more details about the error you met? I guess that some label files have a value of 255 or a value greater than num_classes, so you will meet an error similar to #6.
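
To check whether that is the case, a quick diagnostic along these lines can help (a minimal sketch, not part of this repo; the mask directory and num_classes value below are placeholders to adjust):

```python
# Scan every mask and report label values outside [0, num_classes).
# Values such as 255 (a common binary-foreground or "ignore" value)
# are what trigger errors like the one in #6.
import glob
import numpy as np
from PIL import Image

num_classes = 2  # match your --num_classes setting
for path in glob.glob("path/to/masks/*.png"):  # hypothetical mask directory
    values = np.unique(np.array(Image.open(path)))
    bad = values[values >= num_classes]
    if bad.size:
        print(f"{path}: unexpected label values {bad.tolist()}")
```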

krishnaadithya commented 1 year ago

first error is:

```
python train_learnable_sam.py --image ../Medical-SAM-Adapter/data/fundus/all_combined/train/image/ --mask_path ../Medical-SAM-Adapter/data/fundus/all_combined/train/mask/ --model_name vit_b --checkpoint ../Medical-SAM-Adapter/checkpoint/sam/sam_vit_b_01ec64.pth --save_path ../checkpoint/learnablepromptsam/combine/ --lr 0.05 --mix_precision --optimizer sgd --num_classes 1 --device 0
xFormers not available
train with 5149 imgs
Initial learning rate set to:[0.05]
Traceback (most recent call last):
  File "/home/ubuntu/LearnablePromptSAM/train_learnable_sam.py", line 205, in <module>
    main(args)
  File "/home/ubuntu/LearnablePromptSAM/train_learnable_sam.py", line 181, in main
    loss = loss_func(pred, target)
  File "/opt/conda/envs/medsam/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/medsam/lib/python3.10/site-packages/monai/losses/dice.py", line 733, in forward
    raise ValueError(
ValueError: the number of dimensions for input and target should be the same, got shape torch.Size([1, 1, 1024, 1024]) and torch.Size([1, 1024, 1024]).
```

Qsingle commented 1 year ago


The reason is that the output of the model is a 4D tensor while the mask (label) is a 3D tensor, and the Dice loss requires the two tensors to have the same number of dimensions. You can use mask.unsqueeze(1) to resolve this error.
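For reference, a minimal sketch of the fix, assuming a MONAI DiceLoss and a single-channel binary mask loaded as (B, H, W) (the shapes below mirror the traceback):

```python
import torch
from monai.losses import DiceLoss

loss_func = DiceLoss(sigmoid=True)       # binary case: one output channel
pred = torch.randn(1, 1, 1024, 1024)     # model output: (B, C, H, W)
target = torch.randint(0, 2, (1, 1024, 1024)).float()  # mask: (B, H, W)

# DiceLoss requires pred and target to have the same number of dimensions,
# so insert a channel axis into the mask before computing the loss:
loss = loss_func(pred, target.unsqueeze(1))  # target is now (B, 1, H, W)
```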

Jakkiabc commented 11 months ago

> Thank you for the report. I cannot locate the error from the information you gave; can you give me more details about the error you met? I guess that some label files have a value of 255 or a value greater than num_classes, so you will meet an error similar to #6.

Hello, can you tell me how I can fix this problem? "Target 255 is out of bounds."

Qsingle commented 11 months ago

> Hello, can you tell me how I can fix this problem? "Target 255 is out of bounds."

You can try to use the --divide argument to solve this problem.
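
The exact behaviour of --divide is not shown in this thread. A common cause of "Target 255 is out of bounds" is a binary mask saved with foreground = 255; the sketch below shows the kind of remapping such an option typically performs, as an assumption rather than the repo's actual code:

```python
import numpy as np
from PIL import Image

# Hypothetical example file with values {0, 255}; dividing by 255
# remaps the foreground to 1, giving valid class indices {0, 1}.
mask = np.array(Image.open("example_mask.png"))
mask = mask // 255
```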