aws-neuron / aws-neuron-sdk

`fairseq` transformer models on Trn1 instances lead to compile error #783

Open · SirRob1997 opened this issue 1 year ago

SirRob1997 commented 1 year ago

I'm trying to run a workload using `fairseq`, which should already have XLA support when setting `--tpu` and replacing `--fp16` with `--bf16` (https://github.com/facebookresearch/fairseq/tree/main). However, I'm running into the following error during one of the three `neuronx-cc` compile calls, which seems odd since the previous two calls appear to succeed.

Any ideas on how to resolve this? `fairseq` runs `xm.mark_step()` for each forward pass, so in theory the compiled graphs shouldn't be too large.

[screenshot: neuronx-cc compile error]
SirRob1997 commented 1 year ago

Adding the `--model-type=transformer -O1` compiler flags and heavily downscaling the model and batch size allows the model to compile successfully, but after that there seem to be many more uncached compile calls (one for every forward pass), so the model spends more time compiling than actually training. In fact, I never get an actual training log output.

aws-rhsoln commented 1 year ago

Regarding the large number of graphs produced when running the fairseq library with XLA, here are some of our findings.

  1. XLA flow requires the TPU flag - the `--tpu` CLI arg must be used to trigger the torch-xla flow for pre-training.

  2. Dynamic input sizes trigger recompilation - By default, the fairseq library uses inputs with dynamic sequence lengths for pre-training. This is a problem because each new sequence length passed to the model triggers a new re-compilation. In the worst case, the total number of re-compilations equals the maximum sequence length (for example, 512 re-compilations). Excessive re-compilation greatly increases the overall time-to-train and the memory requirements. The amount of re-compilation caused by dynamic input sequence lengths can be significantly reduced by restricting which input sequence lengths are used for pre-training. This can be accomplished by padding inputs to a single, or a few, predefined sequence lengths (see the first sketch after this list). You can set the `required_seq_len_multiple` argument (https://github.com/facebookresearch/fairseq/blob/da8fb630880d529ab47e53381c30ddc8ad235216/fairseq/dataclass/configs.py#L485) to pad to the max sequence length. Here is the issue for reference: https://github.com/facebookresearch/fairseq/issues/4198

  3. We also see that at https://github.com/facebookresearch/fairseq/blob/main/fairseq/trainer.py#L946 the library performs an `or` between an XLA tensor and a float (`1.0`). This is done to avoid a divide-by-zero, but it causes a copy of the XLA tensor from device to CPU, and hence a graph cut. It can be replaced with `torch.max(sample_size, torch.tensor(1.0))`, which removes the extra graphs at this line (see the second sketch after this list).

  4. Finally, we provide a `neuron_parallel_compile` utility: https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/api-reference-guide/training/pytorch-neuron-parallel-compile.html?highlight=neuron_parallel_compile#pytorch-neuron-neuron-parallel-compile-cli-torch-neuronx . It compiles all of the graphs in parallel and caches them for you, so the actual training run doesn't spend time in compilation.
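To illustrate point 2, here is a minimal sketch of the length-bucketing idea. The `pad_to_multiple` helper, its defaults, and `pad_idx` are hypothetical and only for illustration; in fairseq itself you would set `required_seq_len_multiple` rather than pad batches yourself:

```python
import torch

def pad_to_multiple(tokens: torch.Tensor, multiple: int = 128, pad_idx: int = 1) -> torch.Tensor:
    """Pad a [batch, seq_len] tensor of token ids so seq_len is a multiple of `multiple`.

    Bucketing lengths this way keeps the number of distinct input shapes,
    and therefore the number of XLA re-compilations, small.
    """
    seq_len = tokens.size(1)
    padded_len = -(-seq_len // multiple) * multiple  # round up to the next multiple
    if padded_len == seq_len:
        return tokens
    pad = tokens.new_full((tokens.size(0), padded_len - seq_len), pad_idx)
    return torch.cat([tokens, pad], dim=1)

# Any batch with seq_len between 1 and 128 now produces the same compiled shape.
batch = torch.randint(4, 1000, (8, 57))
print(pad_to_multiple(batch).shape)  # torch.Size([8, 128])
```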
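And a minimal sketch of the change suggested in point 3. It runs on plain CPU tensors purely for illustration; `sample_size` stands in for the scalar computed in fairseq's trainer:

```python
import torch

# Stand-in for the scalar that fairseq's trainer computes on the XLA device.
sample_size = torch.tensor(0.0)

# Original pattern: evaluating `sample_size or 1.0` calls bool() on the tensor,
# which on an XLA device forces a device-to-host copy and cuts the graph here.
divisor_with_graph_cut = sample_size or 1.0

# Suggested pattern: stays on device, so the graph is not cut and no extra
# compilation is triggered at this point.
divisor_on_device = torch.max(sample_size, torch.tensor(1.0))
```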

Let us know if this reduces the number of graphs.

To replicate the compiler failure, it would help if you could share the model type and task type along with the hyper-parameters.

SirRob1997 commented 1 year ago

Sure, here is the command I've been using (I've stripped a few internal flags and reduced the number of language pairs):

```
ulimit -n 65536; fairseq-train data-bin/shard_0/epoch_0 \
  --log-format simple --log-interval 100 \
  --save-dir /mnt/task_wrapper/user_output/artifacts \
  --max-tokens 4096 --update-freq 1 --num-workers 1 \
  --restore-file /does_not_exist.pt \
  --adam-betas "(0.9, 0.98)" --arch transformer --clip-norm 3.0 \
  --criterion label_smoothed_cross_entropy --ddp-backend no_c10d \
  --decoder-attention-heads 8 --decoder-embed-dim 512 --decoder-ffn-embed-dim 8192 --decoder-layers 6 \
  --encoder-attention-heads 8 --encoder-embed-dim 512 --encoder-ffn-embed-dim 8192 --encoder-layers 6 \
  --dropout 0.1 --eval-bleu-print-samples --bf16 --label-smoothing 0.1 \
  --lang-pairs en_US-zh_CN,en_US-ja_JP,en_US-ko_KR,en_US-it_IT,en_US-es_ES,en_US-pt_BR \
  --lr 0.0004 --lr-scheduler inverse_sqrt \
  --max-source-positions 1024 --max-target-positions 1024 --max-update 100000 \
  --no-epoch-checkpoints --optimizer adam --save-interval-updates 2000 --seed 1 \
  --share-all-embeddings --skip-invalid-size-inputs-valid-test --stop-min-lr 1e-09 \
  --task translation_multi_simple_epoch \
  --validate-interval 65536 --validate-interval-updates 2000 \
  --warmup-init-lr 1e-07 --warmup-updates 4000 --weight-decay 0.005 --tpu
```

Thanks for all the pointers, I'll try them! The dynamic sequence length point is especially important and might explain the behaviour I was seeing.

aws-rhsoln commented 1 year ago

Thank you for the command; we are now trying to replicate the issue and will report back once we have reproduced it on our end.