msr-fiddle / pipedream

MIT License
376 stars · 117 forks

Multi node training #36

Open ADAM-CT opened 4 years ago

ADAM-CT commented 4 years ago

I hope you can add examples of multi-node training.

ADAM-CT commented 4 years ago

I wonder if there is something wrong with my setup. Multi-machine PipeDream almost never trains successfully; it reports an error every time. As with the earlier issues I raised, can anyone help me?

deepakn94 commented 4 years ago

Not sure what your method is. We were able to successfully run multi-machine PipeDream experiments.

Some instructions are here: https://github.com/msr-fiddle/pipedream/blob/master/EXPERIMENTS.md. In particular, look for the *_16gpu.yml files, which are all run using multiple servers. Alternatively, please send me the commands you're trying to run, and I can try helping you out.

ADAM-CT commented 4 years ago

Thank you for your reply!

server1:

```
python -m launch --nnodes 2 --node_rank 0 --nproc_per_node 8 main_with_runtime.py --data_dir ../../data --master_addr 172.17.2.1 --module models.vgg16.gpus=16 --distributed_backend gloo -b 64 --lr 0.040000 --lr_policy polynomial --weight-decay 0.000500 --epochs 20 --print-freq 100 --verbose 100 --num_ranks_in_server 8 --config_path models/vgg16/gpus=16/hybrid_conf.json
```

server2:

```
python -m launch --nnodes 2 --node_rank 1 --nproc_per_node 8 main_with_runtime.py --data_dir ../../data --master_addr 172.17.2.1 --module models.vgg16.gpus=16 --distributed_backend gloo -b 64 --lr 0.040000 --lr_policy polynomial --weight-decay 0.000500 --epochs 20 --print-freq 100 --verbose 100 --num_ranks_in_server 8 --config_path models/vgg16/gpus=16/hybrid_conf.json
```
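For reference, my assumption (following the torch.distributed.launch convention, not PipeDream's launch.py itself) is that each process gets a global rank of node_rank * nproc_per_node + local_rank, so the two commands together should cover ranks 0 through 15:

```python
# Sketch of the rank assignment I expect from the two commands above,
# assuming launch.py follows the torch.distributed.launch convention.
nnodes, nproc_per_node = 2, 8

for node_rank in range(nnodes):              # server1 -> 0, server2 -> 1
    for local_rank in range(nproc_per_node):
        global_rank = node_rank * nproc_per_node + local_rank
        print(f"node {node_rank}, local rank {local_rank} -> global rank {global_rank}")
```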

ADAM-CT commented 4 years ago

> Note: the models used below are the ones your code generated; I am just using them for training.

When I use --config_path models/vgg16/gpus=16_straight/hybrid_conf.json, it throws the following error:

```
Traceback (most recent call last):
  File "main_with_runtime.py", line 578, in <module>
    main()
  File "main_with_runtime.py", line 229, in main
    macrobatch=args.macrobatch)
  File "../sgd.py", line 23, in __init__
    macrobatch=macrobatch,
  File "../optimizer.py", line 41, in __init__
    master_parameters, **optimizer_args)
  File "/opt/conda/lib/python3.6/site-packages/torch/optim/sgd.py", line 64, in __init__
    super(SGD, self).__init__(params, defaults)
  File "/opt/conda/lib/python3.6/site-packages/torch/optim/optimizer.py", line 45, in __init__
    raise ValueError("optimizer got an empty parameter list")
ValueError: optimizer got an empty parameter list
```
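For reference, the final ValueError is the generic error PyTorch raises whenever an optimizer is constructed with an empty parameter list; a minimal standalone reproduction (plain PyTorch, not PipeDream code):

```python
import torch
import torch.nn as nn

# A stage made only of parameter-free layers (e.g. ReLU, pooling) exposes an
# empty parameter list.
stage = nn.Sequential(nn.ReLU(), nn.MaxPool2d(kernel_size=2))
params = list(stage.parameters())
print(len(params))  # 0

# Constructing SGD over the empty list raises the same error as in the
# traceback above.
optimizer = torch.optim.SGD(params, lr=0.04)
# ValueError: optimizer got an empty parameter list
```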

When I use --config_path models/vgg16/gpus=16/mp_conf.json, it throws the following error:

```
rank_to_stage_map {0: 0, 1: 1}
Traceback (most recent call last):
  File "main_with_runtime.py", line 578, in <module>
    main()
  File "main_with_runtime.py", line 191, in main
    enable_recompute=args.recompute)
  File "../runtime.py", line 64, in __init__
    master_addr, rank, local_rank, num_ranks_in_server)
  File "../runtime.py", line 150, in initialize
    assert 0 <= self.rank < len(rank_to_stage_map)
AssertionError
```
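If I read the failing assertion correctly, it checks each launched process's rank against the number of entries in the config's rank-to-stage map; a rough illustration (my own sketch, not the runtime code) of why launching 16 processes fails with a two-entry map:

```python
# Values taken from the traceback above and the two launch commands.
rank_to_stage_map = {0: 0, 1: 1}   # from mp_conf.json
num_launched_ranks = 2 * 8         # --nnodes 2 --nproc_per_node 8

for rank in range(num_launched_ranks):
    if not (0 <= rank < len(rank_to_stage_map)):   # the check in runtime.py
        print(f"rank {rank} has no stage assigned and would hit the AssertionError")
```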

But when I use --config_path models/vgg16/gpus=16/hybrid_conf.json, training succeeds!

So I am very confused about why the results differ: my environment has not changed at all, only the configuration file.

My environment:
server1: 8x V100
server2: 8x V100

ADAM-CT commented 4 years ago

I think that, in theory, training directly with your generated models should at least succeed, but something went wrong.

ghost commented 4 years ago

If I understand correctly, there is a problem in models/vgg16/gpus=16/mp_conf.json: the configuration file only contains rank 0 and rank 1, which conflicts with the gpus=16 setting (it should contain ranks 0 through 15).

When I used models/vgg16/gpus=16_straight/hybrid_conf.json, I ran into the same problem as @ADAM-CT. After looking at the error messages, I found that only certain processes failed (processes 0, 2, 4, and 6). I then checked the corresponding stage*.py files for these processes in models/vgg16/gpus=16_straight/ and found that what these stages have in common is that none of them contain any layers with trainable parameters.

It works well after I modified the partition so that each stage has at least one layer with trainable parameters.

I think the cause of this problem is that the current code cannot handle stages that contain no trainable parameters.
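As a quick sanity check before launching, one could scan a partition's stages for this condition; a minimal sketch with toy stages (not PipeDream's actual model-loading code):

```python
import torch.nn as nn

def stages_without_trainable_params(stages):
    """Return the indices of stages that expose no trainable parameters."""
    return [i for i, stage in enumerate(stages)
            if not any(p.requires_grad for p in stage.parameters())]

# Toy example: stage 1 contains only parameter-free layers, so it would make
# the optimizer constructor fail exactly as in the traceback above.
stages = [
    nn.Sequential(nn.Conv2d(3, 64, kernel_size=3), nn.ReLU()),
    nn.Sequential(nn.ReLU(), nn.MaxPool2d(kernel_size=2)),
]
print(stages_without_trainable_params(stages))  # [1]
```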