facebookresearch / access

Code to reproduce the experiments from the paper.

How to make it support multi-GPU training? #27

Closed: Sanqiang closed this issue 3 years ago

louismartin commented 3 years ago

I think you would need to check the fairseq documentation and maybe modify the input arguments that we feed to fairseq: https://github.com/facebookresearch/access/blob/7b61fbf0bad665798d662e0a90d2a0e451367df6/access/fairseq/base.py#L142
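For reference, a rough sketch of what tweaking those arguments could look like, assuming they end up as CLI-style strings fed to fairseq's training parser (the actual plumbing in `access/fairseq/base.py` may assemble them differently, and the data directory and architecture below are just placeholders):

    # Hypothetical sketch: asking fairseq to train on more than one GPU by
    # passing --distributed-world-size through its option parser.
    from fairseq import options

    train_input_args = [
        'data-bin/dataset',                 # hypothetical preprocessed data dir
        '--arch', 'transformer',            # hypothetical architecture
        '--distributed-world-size', '4',    # number of GPUs to train on
    ]
    parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(parser, train_input_args)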

Sanqiang commented 3 years ago

Actually, I tried changing only distributed-world-size and it raised an exception. After reviewing the fairseq source code, it seems I need to do one more thing: I replaced train.main(train_args) with

        # Imports assumed at the top of the module:
        # import random
        # import torch
        # from fairseq import distributed_utils
        port = random.randint(10000, 20000)
        train_args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
        train_args.distributed_rank = None  # set later per process, based on device id
        distributed_utils.infer_init_method(train_args)
        print("Launched %s" % train_args)
        # Launch one training process per GPU; each process receives its
        # device index as the first argument to train.distributed_main.
        torch.multiprocessing.spawn(
            fn=train.distributed_main,
            args=(train_args,),
            nprocs=train_args.distributed_world_size,
        )

Could you confirm whether this is enough? It works on multiple GPUs, but the optimization doesn't seem very efficient.
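For anyone else reading: torch.multiprocessing.spawn invokes fn(i, *args) once per process index i in range(nprocs), so the worker function must accept the rank/device id as its first positional argument (fairseq's train.distributed_main follows that convention). A minimal, self-contained illustration of the contract, with a hypothetical worker rather than fairseq code:

    import torch

    def _worker(rank, cfg):
        # spawn passes the process index as the first positional argument.
        print("worker %d started with config %s" % (rank, cfg))

    if __name__ == '__main__':
        torch.multiprocessing.spawn(fn=_worker, args=({'lr': 0.1},), nprocs=2)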

louismartin commented 3 years ago

Hum, yes, that's a good point. I'm not sure what the exact setup would be to train optimally on multiple GPUs though; I think looking at how to do it in fairseq will be your best bet, as I'm not very knowledgeable! Best, Louis

louismartin commented 3 years ago

Hi @Sanqiang, can I close the issue now?

Nntraveler commented 3 years ago

That also works for me. Before the modification, calling python scripts/train.py raised an error:

RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.

After replacing train.main(train_args), I can train it on multiple GPUs.
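For context on that error: torch.distributed collectives raise this message whenever init_process_group was never called in the current process. The spawn-based path works because each worker ends up joining a process group, roughly along these lines (a generic PyTorch sketch, not the exact fairseq code; the port, world size, and rank below are placeholders mirroring the values set on train_args):

    import torch.distributed as dist

    # Each spawned worker initializes the process group before any collective op.
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://localhost:12345',  # hypothetical port
        world_size=4,                          # hypothetical number of GPUs
        rank=0,                                # this worker's rank
    )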

louismartin commented 3 years ago

Closing for now, feel free to comment or open again if you are still having issues.