juntang-zhuang / GSAM

PyTorch repository for the ICLR 2022 paper GSAM, which improves generalization (e.g., +3.8% top-1 accuracy on ImageNet with ViT-B/32)
MIT License

Distributed Data Parallel Missing in Model and Dataloader #4

Open: samyakjain0112 opened this issue 2 years ago

samyakjain0112 commented 2 years ago

Hi,

I don't see DistributedDataParallel being used when creating the dataloader and the model. Is there a specific reason for this? In general, as shown in https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html, the model needs to be wrapped and a sampler needs to be passed to the dataloader. However, these seem to be missing.

juntang-zhuang commented 2 years ago

Hi, this repo mainly demonstrates gsam.py and how to use it; the train.py script is forked from another repo. You will need to write your own distributed dataloader and model wrapper, as described in the tutorial linked above. I think the best approach is to start from distributed training code you already have in PyTorch and replace the optimizer part with GSAM.
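To illustrate what "replacing the optimizer part" changes inside an existing training loop, here is a minimal sketch of a GSAM-style step, following the update as described in the paper: descend on the perturbed loss f_p and ascend on the surrogate gap along the component of the original gradient orthogonal to the perturbed gradient. The helper `gsam_step` and the `rho`/`alpha` defaults are hypothetical and illustrative; this is not the interface of the repo's gsam.py, and it assumes every trainable parameter receives a gradient.

```python
import torch


def gsam_step(model, loss_fn, x, y, base_optimizer, rho=0.05, alpha=0.1, eps=1e-12):
    """Hypothetical GSAM-style update (sketch, not the repo's gsam.py API)."""
    params = [p for p in model.parameters() if p.requires_grad]

    # 1) Gradient g of the loss at the current weights w.
    base_optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    grads = [p.grad.detach().clone() for p in params]
    grad_norm = torch.sqrt(sum((g * g).sum() for g in grads)) + eps

    # 2) Move to the perturbed point w + rho * g / ||g||.
    perturbs = [rho * g / grad_norm for g in grads]
    with torch.no_grad():
        for p, e in zip(params, perturbs):
            p.add_(e)

    # 3) Gradient g_p of the perturbed loss f_p, taken at w + e.
    base_optimizer.zero_grad()
    loss_fn(model(x), y).backward()
    grads_p = [p.grad.detach().clone() for p in params]

    # 4) Restore the original weights w.
    with torch.no_grad():
        for p, e in zip(params, perturbs):
            p.sub_(e)

    # 5) g_perp = g - (g . g_p / ||g_p||^2) g_p; update direction d = g_p - alpha * g_perp.
    dot = sum((g * gp).sum() for g, gp in zip(grads, grads_p))
    gp_sq = sum((gp * gp).sum() for gp in grads_p) + eps
    with torch.no_grad():
        for p, g, gp in zip(params, grads, grads_p):
            g_perp = g - (dot / gp_sq) * gp
            p.grad.copy_(gp - alpha * g_perp)
    # The base optimizer (e.g. SGD) applies d as if it were the gradient.
    base_optimizer.step()
    return loss.detach()
```

If `model` is wrapped in DistributedDataParallel, both `backward()` calls average gradients across ranks, so every rank computes the same update direction. With `rho=0` and `alpha=0` the step reduces to a plain base-optimizer step.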

This PyTorch implementation is slightly tricky because I don't have the GPU resources to test it, and the same algorithm can sometimes perform slightly differently just from switching between PyTorch and TF/JAX. The work for the paper was done during my internship at Google, and the original implementation was in JAX. You can find the Google repo here: https://github.com/google-research/big_vision. I opened a PR there, but it has not been merged yet. My experiments depend heavily on the big_vision repo, so I had to wait for its release last month before I could publish the JAX code.

samyakjain0112 commented 2 years ago

Thanks @juntang-zhuang for the clarification.