ictnlp-wshugen / annotated-transformer_codes

A partial runnable code repo for annotated-transformer

AssertionError when running on 4 GPUs: len(modules) == len(inputs) #4

Open · rubencart opened this issue 4 years ago

rubencart commented 4 years ago

When running on 2 or 3 GPUs, everything works fine. When running on 4 GPUs, however, the following error occurs during the first epoch:

Epoch step: 1 Loss 9.134843 Tokens per Sec: 161.252618
Traceback (most recent call last):
  File "realworld.py", line 83, in <module>
    MultiGPULossCompute(model.generator, criterion, devices=devices, opt=model_opt))
  File "/export/home1/NoCsBack/hci/rubenc/transformer-v2/transformer/flow.py", line 53, in run_epoch
    for i, batch in enumerate(data_iter):
  File "realworld.py", line 82, in <genexpr>
    run_epoch((rebatch(pad_idx, b) for b in train_iter), model_par,
  File "/export/home1/NoCsBack/hci/rubenc/miniconda3/envs/transfenv/lib/python3.7/site-packages/torchtext/data/iterator.py", line 141, in __iter__
    self.init_epoch()
  File "/export/home1/NoCsBack/hci/rubenc/miniconda3/envs/transfenv/lib/python3.7/site-packages/torchtext/data/iterator.py", line 117, in init_epoch
    self.create_batches()
  File "/export/home1/NoCsBack/hci/rubenc/transformer-v2/transformer/my_iterator.py", line 18, in create_batches
    from realworld import dividable_size
  File "/export/home1/NoCsBack/hci/rubenc/transformer-v2/realworld.py", line 83, in <module>
    MultiGPULossCompute(model.generator, criterion, devices=devices, opt=model_opt))
  File "/export/home1/NoCsBack/hci/rubenc/transformer-v2/transformer/flow.py", line 55, in run_epoch
    loss = loss_compute(out, batch.trg_y, batch.ntokens)
  File "/export/home1/NoCsBack/hci/rubenc/transformer-v2/transformer/multi_gpu_loss_compute.py", line 35, in __call__
    gen = nn.parallel.parallel_apply(generator, out_column)
  File "/export/home1/NoCsBack/hci/rubenc/miniconda3/envs/transfenv/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 37, in parallel_apply
    assert len(modules) == len(inputs)
AssertionError

I suspect this has something to do with https://github.com/pytorch/pytorch/issues/5587 or https://github.com/pytorch/pytorch/issues/11793. Any tips?
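The assertion in the traceback compares the number of replicated `generator` modules (one per device) with the number of input chunks passed to `parallel_apply`. One plausible cause, assumed here rather than confirmed in this thread, is that scattering a batch across devices uses `torch.chunk`-style splitting. That splitting picks a chunk size of ceil(n / devices) and can therefore return fewer chunks than there are devices. The minimal sketch below reproduces only that arithmetic in plain Python. `expected_num_chunks` is a hypothetical helper, not a PyTorch API.

```python
import math


def expected_num_chunks(dim_size: int, num_devices: int) -> int:
    """Hypothetical helper: how many pieces torch.chunk-style splitting yields.

    The split size is ceil(dim_size / num_devices), and the dimension is then
    cut into pieces of that size, so the piece count can fall short of
    num_devices.
    """
    if dim_size <= 0 or num_devices <= 0:
        raise ValueError("dim_size and num_devices must be positive")
    split_size = math.ceil(dim_size / num_devices)
    return math.ceil(dim_size / split_size)


# A batch of 6 sentences: one chunk per GPU on 2 or 3 GPUs, one short on 4 GPUs.
for gpus in (2, 3, 4):
    chunks = expected_num_chunks(6, gpus)
    status = "ok" if chunks == gpus else "len(modules) != len(inputs)"
    print(f"{gpus} GPUs -> {chunks} chunks ({status})")
```

Under this assumption, a batch of 6 sentences splits into 2 chunks on 2 GPUs and 3 chunks on 3 GPUs. On 4 GPUs it also splits into only 3 chunks, which would leave one replicated module without an input and trip the assertion.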

rubencart commented 4 years ago

See the answer here: https://github.com/pytorch/pytorch/issues/11793#issuecomment-553519568. I'll make a PR, for what it's worth.
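One common workaround is to replicate onto only as many devices as there are scattered chunks. This is sketched here under assumptions and is not necessarily what the linked comment or the promised PR does. `torch.nn.DataParallel` follows the same idea internally by replicating onto `device_ids[:len(inputs)]`. The helper `align_replicas` below is hypothetical and uses plain lists so it runs without GPUs. The comment shows where an equivalent change might go in a `MultiGPULossCompute`-style loss. The names `out_scatter` and `self.generator` there are assumed, not taken from this repo.

```python
from typing import List, Sequence, Tuple, TypeVar

M = TypeVar("M")  # stand-in for a replicated module
X = TypeVar("X")  # stand-in for a scattered input chunk


def align_replicas(replicas: Sequence[M], inputs: Sequence[X]) -> Tuple[List[M], List[X]]:
    """Hypothetical helper: keep only as many replicas as there are input chunks.

    parallel_apply asserts len(modules) == len(inputs), so surplus replicas
    (devices that received no chunk) are dropped before the call.
    """
    if len(inputs) > len(replicas):
        raise ValueError("more input chunks than replicas; cannot pair them")
    return list(replicas[: len(inputs)]), list(inputs)


# In a MultiGPULossCompute-style __call__ the equivalent change might look like
# (sketch only; out_scatter and self.generator are assumed names):
#   out_scatter = nn.parallel.scatter(out, target_gpus=devices)
#   generator = nn.parallel.replicate(self.generator, devices[: len(out_scatter)])

devices = [0, 1, 2, 3]
replicas = [f"generator@cuda:{d}" for d in devices]  # 4 replicas, one per GPU
chunks = ["chunk0", "chunk1", "chunk2"]              # scatter produced only 3 chunks
modules, inputs = align_replicas(replicas, chunks)
print(len(modules), len(inputs))  # prints: 3 3
```

If the criterion is replicated the same way, it would presumably need the same trimming, since the targets are split along the same batch dimension as the model output.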