czg1225 / SlimSAM

SlimSAM: 0.1% Data Makes Segment Anything Slim
Apache License 2.0

Questions about multi-GPU training #16

Status: Open. ranpin opened this issue 2 months ago

ranpin commented 2 months ago

Hi, when I train with your code, for example prune_distill_step1.py, I use the command:

CUDA_VISIBLE_DEVICES=0,1 python prune_distill_step1.py --traindata_path <train_data_root> --valdata_path <val_data_root> --prune_ratio <pruning ratio> --epochs <training epochs>

I didn't change anything in prune_distill_step1.py except setting batchsize=8. However, as the screenshot below shows, training does not succeed on multiple GPUs: it reports insufficient GPU memory. Yet I did find the part of your code that uses multiple GPUs for parallel training, for example:
model.image_encoder = torch.nn.DataParallel(model.image_encoder)

Is there a problem with the code, or am I missing some setting? How should I set up multi-GPU training? Looking forward to your reply, thank you very much!

[Screenshot]
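A possible reason the line above does not spread memory evenly: torch.nn.DataParallel replicates only the module it wraps (here model.image_encoder) and splits that module's input batch along dimension 0 across the visible GPUs. Its outputs are gathered back onto device_ids[0] (cuda:0 by default), and anything not wrapped, such as the rest of the model, any teacher network, and the loss, stays wherever it was placed, so GPU 0 can still run out of memory. The sketch below, with a hypothetical helper dataparallel_split, reproduces how a batch is divided, assuming DataParallel's default scatter follows torch.chunk semantics.

```python
import math
from typing import List


def dataparallel_split(batch_size: int, num_gpus: int) -> List[int]:
    """Per-GPU sample counts when a batch is chunked along dim 0.

    Follows torch.chunk semantics: every chunk holds ceil(batch_size / num_gpus)
    samples except possibly the last, so fewer than num_gpus chunks may result.
    Only the wrapped module sees these chunks; gathered outputs return to GPU 0.
    """
    chunk = math.ceil(batch_size / num_gpus)
    sizes = []
    remaining = batch_size
    while remaining > 0:
        sizes.append(min(chunk, remaining))
        remaining -= chunk
    return sizes


# With batchsize=8 and CUDA_VISIBLE_DEVICES=0,1, each image-encoder replica gets 4 samples.
print(dataparallel_split(8, 2))
```

In other words, only the image encoder's forward pass is parallelized; every other tensor in the training step still accumulates on the first GPU.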

czg1225 commented 2 months ago

Hi @ranpin, we also found this problem. We only used a single GPU for training in our implementation, so we were not aware of it before. A possible solution is to implement multi-GPU training with torch.distributed instead of torch.nn.DataParallel.
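A minimal sketch of the torch.distributed route, assuming one process per GPU launched with torchrun. The toy model, random dataset, and the train_ddp function are hypothetical placeholders, not SlimSAM code; in prune_distill_step1.py the wrapped module would presumably be the model being pruned, and any teacher model and loss would also need to live on each process's own device.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def train_ddp(epochs: int = 1, per_gpu_batch_size: int = 4) -> float:
    # torchrun sets these for every process; the defaults let a plain
    # "python script.py" run fall back to a single process.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))

    use_cuda = torch.cuda.is_available()
    device = torch.device(f"cuda:{local_rank}" if use_cuda else "cpu")
    if use_cuda:
        torch.cuda.set_device(device)  # one GPU per process
    dist.init_process_group(backend="nccl" if use_cuda else "gloo")

    # Hypothetical stand-in for the module being trained (e.g. an image encoder).
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16)).to(device)
    ddp_model = DDP(model, device_ids=[local_rank] if use_cuda else None)

    # Toy data; DistributedSampler gives each process a disjoint shard.
    dataset = TensorDataset(torch.randn(64, 16), torch.randn(64, 16))
    sampler = DistributedSampler(dataset, shuffle=True)
    loader = DataLoader(dataset, batch_size=per_gpu_batch_size, sampler=sampler)

    optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=1e-4)
    loss_fn = nn.MSELoss()
    last_loss = float("nan")
    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # different shuffle each epoch
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = loss_fn(ddp_model(x), y)
            loss.backward()  # DDP averages gradients across processes here
            optimizer.step()
            last_loss = loss.item()

    dist.destroy_process_group()
    return last_loss
```

To try it on two GPUs, one would add an entry point that calls train_ddp() and launch it with torchrun --nproc_per_node=2 your_script.py. Unlike DataParallel, DistributedDataParallel keeps a full replica in each process and only exchanges gradients, so activations are not gathered onto a single GPU.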

ranpin commented 2 months ago


OK. If the current code indeed doesn't work, would you mind updating the multi-GPU training code when you have time? I found that batch size also affects the model's training time and results, which may help your paper's results as well. I will also try to implement it myself. Thanks anyway!
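Since batch size matters here, note that with DistributedDataParallel the batch_size passed to each process's DataLoader is per GPU, so one optimizer step covers per-GPU batch size times the number of processes. The sketch below, with hypothetical helpers distributed_shard and global_batch_size, mirrors how torch.utils.data.distributed.DistributedSampler splits indices when shuffle=False and drop_last=False: it pads the index list by wrapping around to a multiple of the world size, then gives each rank every world_size-th index.

```python
import math
from typing import List


def distributed_shard(num_samples: int, world_size: int, rank: int) -> List[int]:
    """Dataset indices one rank sees per epoch (no shuffling, no drop_last)."""
    total = math.ceil(num_samples / world_size) * world_size
    # Pad by wrapping around so every rank gets the same number of samples.
    padded = [i % num_samples for i in range(total)]
    return padded[rank:total:world_size]


def global_batch_size(per_gpu_batch_size: int, world_size: int) -> int:
    # Each process loads its own mini-batch and DDP averages gradients,
    # so a single optimizer step effectively covers this many samples.
    return per_gpu_batch_size * world_size


print(distributed_shard(5, 2, 0), distributed_shard(5, 2, 1))
print(global_batch_size(8, 2))
```

For example, keeping batchsize=8 per process on two GPUs gives an effective batch of 16, which may call for revisiting the learning rate compared with single-GPU runs.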