thomasverelst / dynconv

Code for Dynamic Convolutions: Exploiting Spatial Sparsity for Faster Inference (CVPR2020)

About multi-gpu training #2

Open d-li14 opened 4 years ago

d-li14 commented 4 years ago

Thanks for your awesome work! Do you have any idea how multi-GPU training could be supported? Training ResNet-101 on ImageNet with a single GPU is unacceptably slow.

thomasverelst commented 4 years ago

Hi, thanks for having a look at the code. I did not test multi-GPU training, and RN101 indeed takes quite some time on a single GPU (~2 weeks). I did not make the effort to implement multi-GPU support, since I had to use the other available GPUs in our lab for other runs/experiments. I suspect some changes are needed in the loss. I was planning to look at it anyway in the coming weeks, I'll let you know!

I also plan to release a trained MobileNetV2 with the optimized CUDA code integrated.

d-li14 commented 4 years ago

Hi @thomasverelst, thanks for your prompt reply and for sharing! I understand your concern about computational resources, but two weeks is still a fairly long experimental period :).

Furthermore, I have attempted multi-GPU training by simply wrapping the model with torch.nn.DataParallel, but got stuck on some issues:

Looking forward to your good news! Also congratulations on the upcoming MobileNetV2 CUDA code!

thomasverelst commented 4 years ago

I've pushed a new branch, multigpu. I haven't tested training accuracy yet, but it runs. The only problem I had was gathering the output dict meta. I considered subclassing DataParallel to support meta, but decided to just change the internals so PyTorch wouldn't complain. Note that the pretrained checkpoints differ from those of the master branch (URL in the README).
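For readers hitting the same gathering problem, here is a minimal sketch of one way to make a (logits, meta) output compatible with DataParallel's default gather. It is not the change made in the multigpu branch. It assumes every value in meta can be a tensor with a leading batch dimension; the class TinyNetWithMeta and the key active_fraction are hypothetical names used only for illustration.

```python
import torch
import torch.nn as nn


class TinyNetWithMeta(nn.Module):
    """Hypothetical stand-in for a network that returns (logits, meta)."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x: torch.Tensor):
        feat = torch.relu(self.conv(x))
        logits = self.fc(feat.mean(dim=(2, 3)))
        # Keep one value per sample (shape [batch]) so that DataParallel
        # can concatenate the per-GPU results along dim 0.
        active_fraction = (feat > 0).float().mean(dim=(1, 2, 3))
        meta = {"active_fraction": active_fraction}
        return logits, meta


# DataParallel expects the parameters on the first visible GPU;
# without GPUs it simply calls the wrapped module on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(TinyNetWithMeta().to(device))
logits, meta = model(torch.randn(4, 3, 16, 16, device=device))
```

Plain Python numbers or per-replica scalars inside meta are what usually trip up the gather step, so reducing them to a single scalar is best done after the forward pass, for example in the loss.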

d-li14 commented 4 years ago

Yeah, it seems to work now. I have successfully run this branch with ResNet-32 on CIFAR for fast prototyping (with matched accuracy and reduced FLOPs). As an additional note, the "FLOPs counting to zero" problem can be solved by changing the line model = flopscounter.add_flops_counting_methods(model) to model = flopscounter.add_flops_counting_methods(model.module), since DataParallel wraps the original model.
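The model.module fix can be made independent of whether the model is wrapped. The sketch below uses a hypothetical helper, unwrap_model, which is not part of the repository, and a small stand-in network instead of the real one.

```python
import torch.nn as nn


def unwrap_model(model: nn.Module) -> nn.Module:
    """Return the underlying network if `model` is a (Distributed)DataParallel wrapper."""
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return model.module
    return model


# Small stand-in network; the real code would build the ResNet here.
net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())
wrapped = nn.DataParallel(net)

# Both calls return the same underlying module object.
core_from_wrapped = unwrap_model(wrapped)
core_from_plain = unwrap_model(net)
```

With such a helper, the call from the comment above would presumably become flopscounter.add_flops_counting_methods(unwrap_model(model)), which behaves the same on single-GPU and multi-GPU runs.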

thomasverelst commented 4 years ago

Thanks a lot, that fixed it.