[...]
2023-05-03 20:47:40 10.130.0.13 [1] TypeError: _foreach_mul_() received an invalid combination of arguments - got (list, Tensor), but expected one of:
2023-05-03 20:47:40 10.130.0.13 [1] * (tuple of Tensors self, Number scalar)
2023-05-03 20:47:40 10.130.0.13 [1] didn't match because some of the arguments have invalid types: (!list!, !Tensor!)
2023-05-03 20:47:40 10.130.0.13 [1] * (tuple of Tensors self, tuple of Scalars scalars)
2023-05-03 20:47:40 10.130.0.13 [1] didn't match because some of the arguments have invalid types: (!list!, !Tensor!)
2023-05-03 20:47:40 10.130.0.13 [1] * (tuple of Tensors self, tuple of Tensors other)
2023-05-03 20:47:40 10.130.0.13 [1] didn't match because some of the arguments have invalid types: (!list!, !Tensor!)
2023-05-03 20:47:40 10.130.0.13 [1]
2023-05-03 20:47:40 10.130.0.11 [0] Exception in device=TPU:2: _foreach_mul_() received an invalid combination of arguments - got (list, Tensor), but expected one of:
2023-05-03 20:47:40 10.130.0.11 [0] * (tuple of Tensors self, Number scalar)
2023-05-03 20:47:40 10.130.0.11 [0] didn't match because some of the arguments have invalid types: (!list!, !Tensor!)
2023-05-03 20:47:40 10.130.0.11 [0] * (tuple of Tensors self, tuple of Scalars scalars)
2023-05-03 20:47:40 10.130.0.11 [0] didn't match because some of the arguments have invalid types: (!list!, !Tensor!)
2023-05-03 20:47:40 10.130.0.11 [0] * (tuple of Tensors self, tuple of Tensors other)
2023-05-03 20:47:40 10.130.0.11 [0] didn't match because some of the arguments have invalid types: (!list!, !Tensor!)
2023-05-03 20:47:40 10.130.0.11 [0]
2023-05-03 20:47:40 10.130.0.11 [0] Traceback (most recent call last):
2023-05-03 20:47:40 10.130.0.11 [0] File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 331, in _mp_start_fn
2023-05-03 20:47:40 10.130.0.11 [0] _start_fn(index, pf_cfg, fn, args)
2023-05-03 20:47:40 10.130.0.11 [0] File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 325, in _start_fn
2023-05-03 20:47:40 10.130.0.11 [0] fn(gindex, *args)
2023-05-03 20:47:40 10.130.0.11 [0] File "/home/scheible/fairseq/fairseq/distributed/utils.py", line 362, in distributed_main
2023-05-03 20:47:40 10.130.0.11 [0] main(cfg, **kwargs)
2023-05-03 20:47:40 10.130.0.11 [0] File "/home/scheible/fairseq/fairseq_cli/train.py", line 205, in main
2023-05-03 20:47:40 10.130.0.11 [0] valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
2023-05-03 20:47:40 10.130.0.11 [0] File "/usr/lib/python3.8/contextlib.py", line 75, in inner
2023-05-03 20:47:40 10.130.0.11 [0] return func(*args, **kwds)
2023-05-03 20:47:40 10.130.0.11 [0] File "/home/scheible/fairseq/fairseq_cli/train.py", line 331, in train
2023-05-03 20:47:40 10.130.0.11 [0] log_output = trainer.train_step(samples)
2023-05-03 20:47:40 10.130.0.11 [0] File "/usr/lib/python3.8/contextlib.py", line 75, in inner
2023-05-03 20:47:40 10.130.0.11 [0] return func(*args, **kwds)
2023-05-03 20:47:40 10.130.0.11 [0] File "/home/scheible/fairseq/fairseq/trainer.py", line 946, in train_step
2023-05-03 20:47:40 10.130.0.11 [0] self.optimizer.multiply_grads(numer / (sample_size or 1.0))
2023-05-03 20:47:40 10.130.0.11 [0] File "/home/scheible/fairseq/fairseq/optim/fairseq_optimizer.py", line 116, in multiply_grads
2023-05-03 20:47:40 10.130.0.11 [0] torch._foreach_mul_(grads, c.to(device) if torch.is_tensor(c) else c)
[...]
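The overload list in the error shows that `torch._foreach_mul_` in this PyTorch build accepts a Number, a sequence of Scalars, or a sequence of Tensors, but not a single Tensor. `multiply_grads` passes exactly that whenever `c` is a tensor. Below is a minimal sketch of a possible local workaround. It assumes that falling back to a per-gradient `Tensor.mul_` is acceptable when `c` is a tensor. The helper name `multiply_grads_compat` is hypothetical and is not part of fairseq.

```python
from collections import defaultdict
from typing import Iterable, Union

import torch


def multiply_grads_compat(
    params: Iterable[torch.nn.Parameter], c: Union[float, torch.Tensor]
) -> None:
    """Hypothetical helper: scale every .grad by c in place.

    torch._foreach_mul_ is only called with a Python number, which matches the
    (tuple of Tensors self, Number scalar) overload listed in the error.
    A tensor c is applied per gradient with Tensor.mul_ instead.
    """
    grouped = defaultdict(list)
    for p in params:
        if p.grad is None:
            continue
        if torch.is_tensor(c) or p.grad.is_sparse:
            # Per-tensor fallback: mul_ accepts a 0-dim tensor (moved to the grad's device).
            p.grad.mul_(c.to(p.grad.device) if torch.is_tensor(c) else c)
        else:
            # Group dense grads by device and dtype, as foreach ops expect.
            grouped[(p.grad.device, p.grad.dtype)].append(p.grad)
    for grads in grouped.values():
        # Only reached when c is a Python number.
        torch._foreach_mul_(grads, c)
```

Calling `c.item()` would also satisfy the Number overload. On XLA devices, however, it would likely force a device-to-host transfer on every step. The per-tensor loop avoids that at the cost of issuing one op per gradient.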
Expected behavior
I expected the RoBERTa model to be trained.
Environment
fairseq Version: main
PyTorch Version: 1.13
OS (e.g., Linux): Linux
How you installed fairseq (pip, source): source (fork with TPUv4 change)
Build command you used (if compiling from source): pip install .
🐛 Bug
Training on a TPUv4 Pod crashes with an error. To get it to run on TPUv4 at all, I had to change one value in the spawn (see https://github.com/facebookresearch/fairseq/compare/main...scheiblr:fairseq:TPUv4). On a single TPUv4 it runs (with the pip installation).
To Reproduce
Error: