samyakjain0112 opened this issue 2 years ago
Hi, this repo mainly provides gsam.py and shows how to use it; the train.py script is forked from another repo. You will need to write your own distributed data loader and model, as described in the docs above. The best case is that you already have distributed training code in PyTorch and only replace the optimizer part with GSAM.
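A minimal sketch of the "replace the optimizer part" approach, assuming an existing PyTorch loop that builds its optimizer in one place. `make_optimizer` and `train_one_epoch` are hypothetical names, and plain `torch.optim.SGD` is only a stand-in: GSAM's actual constructor arguments and step interface are defined in `gsam.py` and are not reproduced here, so the update lines inside the loop would also need to follow its documented usage.

```python
import torch
import torch.nn as nn


def make_optimizer(model: nn.Module) -> torch.optim.Optimizer:
    # Hypothetical single swap point. In a real script this is where the base
    # optimizer would be wrapped by GSAM (see gsam.py for its arguments).
    # Plain SGD is used here purely as a stand-in.
    return torch.optim.SGD(model.parameters(), lr=0.1)


def train_one_epoch(model: nn.Module, loader, optimizer: torch.optim.Optimizer) -> float:
    # Standard loop; only the update step below would change for GSAM.
    loss_fn = nn.MSELoss()
    total = 0.0
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()  # stand-in for the GSAM update
        total += loss.item()
    return total / len(loader)
```

The point of this layout is that data loading, model definition, and any distributed wrapping stay untouched; only the optimizer construction and the update step change.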
It's slightly tricky with this PyTorch implementation because I don't have the GPU resources to test it. Sometimes the same algorithm can perform slightly differently just from switching between PyTorch and TF/JAX. The work for the paper was done during my internship at Google, and the original implementation was in JAX. You can find the Google repo here: https://github.com/google-research/big_vision. I opened a PR there, but it has not been merged yet. My experiments depend heavily on the Google big_vision repo, so in order to publish the JAX code I had to wait until their release last month.
Thanks @juntang-zhuang for the clarification.
Hi,
I don't see DistributedDataParallel being used when creating the dataloader and the model. Is there a specific reason for this? In general, as described in https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html, the model needs to be wrapped and a sampler needs to be passed to the dataloader. However, both of these seem to be missing.
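For reference, the two pieces the linked tutorial describes can be sketched as below, assuming a single node, one process per rank, and the CPU `gloo` backend. The helper names (`init_distributed`, `make_sharded_loader`, `wrap_model`), the address, and the port are placeholders, not part of this repo. `DistributedSampler` gives each rank its own shard of the data, and `DistributedDataParallel` averages gradients across ranks during `backward()`.

```python
import os

import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler


def init_distributed(rank: int, world_size: int, backend: str = "gloo") -> None:
    # Single-node rendezvous; address and port are placeholder assumptions.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    # "gloo" works on CPU; "nccl" is the usual choice for multi-GPU training.
    dist.init_process_group(backend, rank=rank, world_size=world_size)


def make_sharded_loader(dataset: Dataset, rank: int, world_size: int,
                        batch_size: int = 8, shuffle: bool = True) -> DataLoader:
    # Each rank receives a disjoint subset of indices; if len(dataset) is not
    # divisible by world_size, a few samples are repeated as padding.
    sampler = DistributedSampler(dataset, num_replicas=world_size,
                                 rank=rank, shuffle=shuffle)
    # Shuffling is delegated to the sampler, so DataLoader's shuffle stays off.
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)


def wrap_model(model: nn.Module, device_id=None) -> nn.Module:
    # Requires init_distributed() to have been called in this process.
    if device_id is None:  # CPU / gloo
        return DDP(model)
    model = model.to(device_id)
    return DDP(model, device_ids=[device_id])
```

Each process would call `init_distributed(rank, world_size)` before `wrap_model`, and call `loader.sampler.set_epoch(epoch)` at the start of every epoch so that the shuffle order changes between epochs.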