SirRob1997 opened this issue 1 year ago
Adding the --model-type=transformer -O1 flags and heavily downscaling the model and batch size allows the model to compile successfully, but after that there seem to be many more compile calls that are uncached (one for every forward pass), so the model spends more time compiling than actually training. In fact, I never get an actual training log output.
Regarding the too-many-graphs problem with the fairseq library on XLA, here are some of our findings.
XLA flow requires the TPU flag - The --tpu CLI arg must be used to trigger the torch-xla flow for pre-training.
Dynamic input size triggers recompilation - By default, the fairseq library uses inputs with dynamic sequence lengths for pre-training. This is a problem because each new sequence length passed to the model triggers a new re-compilation. In the worst case, the total number of re-compilations equals the maximum sequence length (for example, 512 re-compilations). Excessive re-compilation greatly increases the overall time-to-train and the memory requirements. The amount of re-compilation due to dynamic input sequence lengths can be significantly reduced by restricting which input sequence lengths are used for pre-training, i.e. by padding inputs to a single, or a few, predefined sequence lengths. You can set the required_seq_len_multiple argument (https://github.com/facebookresearch/fairseq/blob/da8fb630880d529ab47e53381c30ddc8ad235216/fairseq/dataclass/configs.py#L485) to pad to the max sequence length; see the sketch below.
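To illustrate why this helps, here is a minimal sketch of padding a batch so its length lands on a fixed multiple. The helper name and signature are hypothetical and this is not fairseq's actual collater:

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(tokens: torch.Tensor, multiple: int, pad_idx: int) -> torch.Tensor:
    """Pad a (batch, seq_len) tensor so seq_len is a multiple of `multiple`.

    With multiple=512, every batch ends up at 512 or 1024 tokens, so XLA only
    ever sees a couple of distinct input shapes and compiles a couple of
    graphs, instead of one graph per observed sequence length.
    """
    seq_len = tokens.size(1)
    remainder = seq_len % multiple
    if remainder == 0:
        return tokens
    return F.pad(tokens, (0, multiple - remainder), value=pad_idx)
```

Setting the multiple to the model's max positions (1024 in the command below) would effectively pad every batch to a single length, trading some wasted computation on padding for a single compiled graph.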
Here is the issue for reference: https://github.com/facebookresearch/fairseq/issues/4198
We also see that at https://github.com/facebookresearch/fairseq/blob/main/fairseq/trainer.py#L946 the library performs an or between an XLA tensor and a float (1.0). This is done to avoid a divide-by-zero, but evaluating the or forces a copy of the XLA tensor from device to CPU, and hence a graph cut. It can be replaced with torch.max(sample_size, torch.tensor(1.0)), so no extra graphs are produced at this line.
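As a minimal sketch of the change (the helper name is ours, and this is an illustration rather than a patch to trainer.py):

```python
import torch

def safe_sample_size(sample_size: torch.Tensor) -> torch.Tensor:
    # `sample_size or 1.0` calls bool() on the tensor, which pulls the XLA
    # tensor to the CPU and cuts the lazy graph.  Clamping to a minimum of
    # 1.0 on-device avoids the division by zero without leaving the graph.
    return torch.max(sample_size, torch.tensor(1.0, device=sample_size.device))
```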
Finally, we provide a parallel_compile utility: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/api-reference-guide/training/pytorch-neuron-parallel-compile.html?highlight=neuron_parallel_compile#pytorch-neuron-neuron-parallel-compile-cli-torch-neuronx . It compiles all the graphs in parallel and caches them for you, so that the actual training run doesn't spend time in compilation.
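For example (the exact flags for the short trial run are illustrative), you would prefix a shortened version of your training command with the wrapper to populate the cache, then launch the real run with the same command:

```bash
# Short trial run under the wrapper: graphs are extracted, compiled in
# parallel, and cached (the reduced --max-update value here is illustrative).
neuron_parallel_compile fairseq-train data-bin/shard_0/epoch_0 --tpu --bf16 --max-update 100 [remaining flags]

# Real training run: identical command, now served from the compile cache.
fairseq-train data-bin/shard_0/epoch_0 --tpu --bf16 --max-update 100000 [remaining flags]
```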
Let us know if these changes improve the number-of-graphs issue.
To replicate the compiler failure, could you share the model type and task type along with the hyper-parameters? That would help.
Sure, here is the command I've been using (I've stripped a few internal flags and reduced the number of language pairs):
ulimit -n 65536; fairseq-train data-bin/shard_0/epoch_0 --log-format simple --log-interval 100 --save-dir /mnt/task_wrapper/user_output/artifacts --max-tokens 4096 --update-freq 1 --num-workers 1 --restore-file /does_not_exist.pt --adam-betas "(0.9, 0.98)" --arch transformer --clip-norm 3.0 --criterion label_smoothed_cross_entropy --ddp-backend no_c10d --decoder-attention-heads 8 --decoder-embed-dim 512 --decoder-ffn-embed-dim 8192 --decoder-layers 6 --dropout 0.1 --encoder-attention-heads 8 --encoder-embed-dim 512 --encoder-ffn-embed-dim 8192 --encoder-layers 6 --eval-bleu-print-samples --bf16 --label-smoothing 0.1 --lang-pairs en_US-zh_CN,en_US-ja_JP,en_US-ko_KR,en_US-it_IT,en_US-es_ES,en_US-pt_BR --lr 0.0004 --lr-scheduler inverse_sqrt --max-source-positions 1024 --max-target-positions 1024 --max-update 100000 --no-epoch-checkpoints --optimizer adam --save-interval-updates 2000 --seed 1 --share-all-embeddings --skip-invalid-size-inputs-valid-test --stop-min-lr 1e-09 --task translation_multi_simple_epoch --validate-interval 65536 --validate-interval-updates 2000 --warmup-init-lr 1e-07 --warmup-updates 4000 --weight-decay 0.005 --tpu
Thanks for all the pointers, I'll try them! The dynamic input length especially seems important and might explain the behaviour I was seeing.
Thank you for the command; we are now trying to replicate the issue and will report back once we have reproduced it on our end.
I'm trying to run a workload using fairseq, which should already have XLA support (https://github.com/facebookresearch/fairseq/tree/main), by setting --tpu and replacing --fp16 with --bf16, but I'm running into the following error during one of the three neuronx-cc compile calls, which seems odd since the previous two calls seem to succeed. Any ideas how to resolve this? fairseq runs xm.mark_step() for each forward, so theoretically the instructions shouldn't be too large.
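For context, this is roughly where xm.mark_step() sits in a generic torch-xla training loop (a simplified sketch, not fairseq's trainer; the model, shapes and hyper-parameters are placeholders):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(512, 512).to(device)        # stand-in for the transformer
optimizer = torch.optim.Adam(model.parameters(), lr=4e-4)

for step in range(10):
    batch = torch.randn(8, 512, device=device)      # fixed shape -> one cached graph
    loss = model(batch).pow(2).mean()
    loss.backward()
    xm.optimizer_step(optimizer)                    # all-reduce (if distributed) + update
    optimizer.zero_grad()
    xm.mark_step()                                  # cut the lazy graph here, once per
                                                    # forward/step, so graphs stay small
```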