facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

Issues computing RoBERTa on TPUv4 POD #5098

Open scheiblr opened 1 year ago

scheiblr commented 1 year ago

🐛 Bug

Computation on a TPUv4 POD crashes with an error. To get it to run on TPUv4 at all, I had to change one value in the spawn (see https://github.com/facebookresearch/fairseq/compare/main...scheiblr:fairseq:TPUv4). On a single TPUv4 it runs (with the pip installation).

To Reproduce

python3 -m torch_xla.distributed.xla_dist   --tpu=$TPU_NAME   --restart-tpuvm-pod-server   --env=MKL_THREADING_LAYER=GNU   --env=TPUv4=True   -- \
fairseq-train "/home/scheible/data/clean-small" "--distributed-world-size" "8" "--tpu" "--log-format" "json" "--log-interval" "25" "--task" "masked_lm" "--criterion" "masked_lm" "--optimizer" "adam" "--adam-betas" "(0.9,0.98)" "--adam-eps" "1e-6" "--clip-norm" "0.0" "--arch" "roberta_base" "--save-dir" "/home/scheible/checkpoints" "--sample-break-mode" "none" "--tokens-per-sample" "512" "--lr-scheduler" "polynomial_decay" "--lr" "0.0004" "--total-num-update" "100000" "--warmup-updates" "10000" "--dropout" "0.1" "--attention-dropout" "0.1" "--weight-decay" "0.01" "--batch-size" "64" "--update-freq" "16" "--skip-invalid-size-inputs-valid-test" "--save-interval-updates" "100000" "--max-update" "100000"

Error:

[...]
2023-05-03 20:47:40 10.130.0.13 [1] TypeError: _foreach_mul_() received an invalid combination of arguments - got (list, Tensor), but expected one of:
2023-05-03 20:47:40 10.130.0.13 [1]  * (tuple of Tensors self, Number scalar)
2023-05-03 20:47:40 10.130.0.13 [1]       didn't match because some of the arguments have invalid types: (!list!, !Tensor!)
2023-05-03 20:47:40 10.130.0.13 [1]  * (tuple of Tensors self, tuple of Scalars scalars)
2023-05-03 20:47:40 10.130.0.13 [1]       didn't match because some of the arguments have invalid types: (!list!, !Tensor!)
2023-05-03 20:47:40 10.130.0.13 [1]  * (tuple of Tensors self, tuple of Tensors other)
2023-05-03 20:47:40 10.130.0.13 [1]       didn't match because some of the arguments have invalid types: (!list!, !Tensor!)
2023-05-03 20:47:40 10.130.0.13 [1]
2023-05-03 20:47:40 10.130.0.11 [0] Exception in device=TPU:2: _foreach_mul_() received an invalid combination of arguments - got (list, Tensor), but expected one of:
2023-05-03 20:47:40 10.130.0.11 [0]  * (tuple of Tensors self, Number scalar)
2023-05-03 20:47:40 10.130.0.11 [0]       didn't match because some of the arguments have invalid types: (!list!, !Tensor!)
2023-05-03 20:47:40 10.130.0.11 [0]  * (tuple of Tensors self, tuple of Scalars scalars)
2023-05-03 20:47:40 10.130.0.11 [0]       didn't match because some of the arguments have invalid types: (!list!, !Tensor!)
2023-05-03 20:47:40 10.130.0.11 [0]  * (tuple of Tensors self, tuple of Tensors other)
2023-05-03 20:47:40 10.130.0.11 [0]       didn't match because some of the arguments have invalid types: (!list!, !Tensor!)
2023-05-03 20:47:40 10.130.0.11 [0]
2023-05-03 20:47:40 10.130.0.11 [0] Traceback (most recent call last):
2023-05-03 20:47:40 10.130.0.11 [0]   File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 331, in _mp_start_fn
2023-05-03 20:47:40 10.130.0.11 [0]     _start_fn(index, pf_cfg, fn, args)
2023-05-03 20:47:40 10.130.0.11 [0]   File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 325, in _start_fn
2023-05-03 20:47:40 10.130.0.11 [0]     fn(gindex, *args)
2023-05-03 20:47:40 10.130.0.11 [0]   File "/home/scheible/fairseq/fairseq/distributed/utils.py", line 362, in distributed_main
2023-05-03 20:47:40 10.130.0.11 [0]     main(cfg, **kwargs)
2023-05-03 20:47:40 10.130.0.11 [0]   File "/home/scheible/fairseq/fairseq_cli/train.py", line 205, in main
2023-05-03 20:47:40 10.130.0.11 [0]     valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
2023-05-03 20:47:40 10.130.0.11 [0]   File "/usr/lib/python3.8/contextlib.py", line 75, in inner
2023-05-03 20:47:40 10.130.0.11 [0]     return func(*args, **kwds)
2023-05-03 20:47:40 10.130.0.11 [0]   File "/home/scheible/fairseq/fairseq_cli/train.py", line 331, in train
2023-05-03 20:47:40 10.130.0.11 [0]     log_output = trainer.train_step(samples)
2023-05-03 20:47:40 10.130.0.11 [0]   File "/usr/lib/python3.8/contextlib.py", line 75, in inner
2023-05-03 20:47:40 10.130.0.11 [0]     return func(*args, **kwds)
2023-05-03 20:47:40 10.130.0.11 [0]   File "/home/scheible/fairseq/fairseq/trainer.py", line 946, in train_step
2023-05-03 20:47:40 10.130.0.11 [0]     self.optimizer.multiply_grads(numer / (sample_size or 1.0))
2023-05-03 20:47:40 10.130.0.11 [0]   File "/home/scheible/fairseq/fairseq/optim/fairseq_optimizer.py", line 116, in multiply_grads
2023-05-03 20:47:40 10.130.0.11 [0]     torch._foreach_mul_(grads, c.to(device) if torch.is_tensor(c) else c)
[...]

Expected behavior

I expected the RoBERTa model to be trained.

Environment

scheiblr commented 1 year ago

Solved it by reverting the changes from commit https://github.com/facebookresearch/fairseq/commit/3c1abb59f581bfce68822b6179d1c1b37b304259
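
For anyone who cannot revert the commit: the traceback points at the torch._foreach_mul_(grads, c) call in fairseq_optimizer.py, which was handed a list plus a Tensor multiplier and found no matching overload in the installed torch build. Below is a minimal sketch of a per-parameter alternative that avoids the foreach call. The function name multiply_grads_per_param is hypothetical, and I am assuming (not confirming) that this resembles what the revert restores. It is duck-typed on purpose: it only relies on .grad, .device, .to() and .mul_(), so the snippet itself does not import torch.

```python
def multiply_grads_per_param(params, c):
    """Scale each parameter's gradient by c, one parameter at a time.

    Hypothetical helper, not fairseq's actual code: it sidesteps
    torch._foreach_mul_ with a Tensor multiplier, the call that raised
    the TypeError in the traceback above.
    """
    for p in params:
        grad = getattr(p, "grad", None)
        if grad is None:
            continue  # this parameter received no gradient in this step
        # Tensor-like multipliers expose .to(); move them to the gradient's
        # device first. Plain Python numbers are used unchanged.
        factor = c.to(grad.device) if hasattr(c, "to") else c
        grad.mul_(factor)  # in-place multiply on a single gradient
```

Inside an optimizer wrapper this would be called as multiply_grads_per_param(self.params, c) in place of the foreach call, assuming self.params yields the trainable parameters. The loop issues one multiply per parameter instead of one batched call, so it may be slower on large models, but it keeps the multiplier as a tensor rather than forcing it to a Python number.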