huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

[Benchmark] HF Trainer on A100 #15026

Open stas00 opened 2 years ago

stas00 commented 2 years ago

🖥 Benchmarking transformers w/ HF Trainer on a single A100 40GB

We are going to use a special benchmarking tool that will do all the work for us. https://github.com/huggingface/transformers/pull/14934

This is the index post and specific benchmarks are in their own posts below:

  1. fp16 vs bf16 vs tf32 vs fp32
  2. gradient accumulation steps
  3. batch size
  4. gradient checkpointing
  5. optimizers
  6. combining winning strategies ~3x speed improvement!
  7. RTX-3090 vs A100

Note that each benchmark was run only once; multiple runs with averaging would probably give slightly different results. The purpose here is to see rough relative differences, not to produce exact numbers.

See also the same benchmarks for RTX-3090

stas00 commented 2 years ago

precision: fp16 vs bf16 vs tf32 vs fp32

Main interest: benchmarking the new --bf16 and --tf32 modes on Ampere (here an A100), compared to the fp16 and fp32 modes.

Benchmark

The benchmark uses 3 different t5 models, and at the end of the section also gpt2 and gpt2-medium. For t5 the main script is:

CUDA_VISIBLE_DEVICES=0 python \
 examples/pytorch/translation/run_translation.py --model_name_or_path t5-small \
--output_dir output_dir --do_train --label_smoothing 0.1 --logging_strategy no \
--save_strategy no --per_device_train_batch_size 64 --max_source_length 512 \
--max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: " --warmup_steps 50 \
--max_train_samples 40000 --dataloader_num_workers 2 

and now adding one of:

--tf32 0 # fp32
--tf32 0 --fp16
--tf32 0 --bf16
--tf32 1
--tf32 1 --fp16
--tf32 1 --bf16
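
For orientation, here is a rough sketch of what these flags map to at the plain PyTorch level (assuming a recent torch; the Trainer wires this up internally from --tf32/--fp16/--bf16, so this is illustrative, not the Trainer's actual code path):

import torch

# --tf32 1 / --tf32 0: allow or forbid TF32 math in matmul and cuDNN kernels on Ampere
torch.backends.cuda.matmul.allow_tf32 = True   # --tf32 1
torch.backends.cudnn.allow_tf32 = True         # --tf32 1

# --fp16 / --bf16: mixed precision via autocast (+ a gradient scaler for fp16)
model = torch.nn.Linear(512, 512, device="cuda")
x = torch.randn(8, 512, device="cuda")
scaler = torch.cuda.amp.GradScaler()           # only needed for the fp16 path

with torch.cuda.amp.autocast(dtype=torch.float16):   # torch.bfloat16 for --bf16
    loss = model(x).float().pow(2).mean()

scaler.scale(loss).backward()                  # bf16 typically runs without the scaler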

But we are going to use a special benchmarking tool that will do all the work for us. https://github.com/huggingface/transformers/pull/14934

Important notes:

  1. The --tf32 0 --fp16 0 combo is just fp32 (which is the default mode, so there is no explicit fp32 flag per se)
  2. I changed --per_device_train_batch_size in the base command from 64 (t5-small) to 32 (t5-base) to 8 (t5-large) to be able to fit into the GPU memory while keeping it as occupied as possible.
  3. I changed --max_train_samples in the base command from 40k (t5-small) to 20k (t5-base) to 5k (t5-large) to give each run about 1-3 min of run time, so that the benchmark doesn't take too long but is still long enough to put strain on the card.

*** Setup:

Datetime    : 2022-01-03 22:43:38

Software:
transformers: 4.16.0.dev0
torch       : 1.10.1
cuda        : 11.3
python      : 3.8.12

Hardware:
1 GPUs      : NVIDIA A100-SXM4-40GB, 39.59GB

Benchmark 1: t5-small

| Variation | Train samples per second | Diff % | Train loss |
|:----------|-------------------------:|-------:|-----------:|
| --tf32 0 | 272.59 | 0 | 2.49 |
| --tf32 1 | 581.61 | 113 | 2.49 |
| --fp16 --tf32 0 | 643.07 | 136 | 2.49 |
| --fp16 --tf32 1 | 635.24 | 133 | 2.49 |
| --bf16 --tf32 0 | 616.23 | 126 | 2.50 |
| --bf16 --tf32 1 | 612.59 | 125 | 2.50 |

Conclusions:

CUDA_VISIBLE_DEVICES=3 python ./scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' examples/pytorch/translation/run_translation.py --model_name_or_path t5-small \
--output_dir output_dir --do_train --label_smoothing 0.1 --logging_strategy no \
--save_strategy no --per_device_train_batch_size 64 --max_source_length 512 \
--max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: " --warmup_steps 50 \
--max_train_samples 40000 --dataloader_num_workers 2 ' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'|--fp16|--bf16' '--tf32 0|--tf32 1' --report-metric-keys train_loss \
--repeat-times 1 --base-variation '--tf32 0'

Benchmark 2: t5-base

| Variation | Train samples per second | Diff % | Train loss |
|:----------|-------------------------:|-------:|-----------:|
| --tf32 0 | 80.10 | 0 | 2.21 |
| --tf32 1 | 214.10 | 167 | 2.21 |
| --fp16 --tf32 0 | 219.20 | 174 | 2.21 |
| --fp16 --tf32 1 | 218.46 | 173 | 2.21 |
| --bf16 --tf32 0 | 214.17 | 167 | 2.22 |
| --bf16 --tf32 1 | 225.44 | 181 | 2.22 |

Conclusions:

CUDA_VISIBLE_DEVICES=0 python ./scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' examples/pytorch/translation/run_translation.py --model_name_or_path t5-base \
--output_dir output_dir --do_train --label_smoothing 0.1 --logging_strategy no \
--save_strategy no --per_device_train_batch_size 32 --max_source_length 512 \
--max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: " --warmup_steps 50 \
--max_train_samples 20000 --dataloader_num_workers 2 ' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'|--fp16|--bf16' '--tf32 0|--tf32 1' --report-metric-keys train_loss \
--repeat-times 1 --base-variation '--tf32 0'

Benchmark 3: t5-large

| Variation | Train samples per second | Diff % | Train loss |
|:----------|-------------------------:|-------:|-----------:|
| --tf32 0 | 31.59 | 0 | 2.03 |
| --tf32 1 | 36.13 | 14 | 2.03 |
| --fp16 --tf32 0 | 34.86 | 10 | 0.00 |
| --fp16 --tf32 1 | 36.77 | 16 | 0.00 |
| --bf16 --tf32 0 | 31.35 | -1 | 2.04 |
| --bf16 --tf32 1 | 31.30 | -1 | 2.04 |

Conclusions:

CUDA_VISIBLE_DEVICES=3 python ./scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' examples/pytorch/translation/run_translation.py --model_name_or_path t5-large \
--output_dir output_dir --do_train --label_smoothing 0.1 --logging_strategy no \
--save_strategy no --per_device_train_batch_size 8 --max_source_length 512 \
--max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: " --warmup_steps 50 \
--max_train_samples 5000 --dataloader_num_workers 2 ' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'|--fp16|--bf16' '--tf32 0|--tf32 1' --report-metric-keys train_loss \
--repeat-times 1 --base-variation '--tf32 0'

If I use a higher bs=16 instead of 8, bf16 does deliver better performance than fp32, but still not on par with fp16:

| Variation | Train samples per second | Diff % | Train loss |
|:----------|-------------------------:|-------:|-----------:|
| --tf32 0 | 39.59 | 0 | 2.10 |
| --tf32 1 | 67.24 | 70 | 2.10 |
| --fp16 --tf32 0 | 70.88 | 79 | 0.00 |
| --fp16 --tf32 1 | 70.38 | 78 | 0.00 |
| --bf16 --tf32 0 | 61.37 | 55 | 2.12 |
| --bf16 --tf32 1 | 59.95 | 51 | 2.12 |

It'd be great to know why CUDA doesn't activate some optimization here, since not everybody is going to run benchmarks. But if you do run benchmarks and find yourself in this situation, Eddie Yan proposed adding --gradient_accumulation_steps to create a much larger effective batch per optimizer step, which should help a lot.

So let's try --per_device_train_batch_size 16 --gradient_accumulation_steps 4 for a total of effective bs=64:

| Variation | Train samples per second | Diff % | Train loss |
|:----------|-------------------------:|-------:|-----------:|
| --tf32 0 | 45.60 | 0 | 2.35 |
| --tf32 1 | 86.78 | 90 | 2.36 |
| --fp16 --tf32 0 | 77.47 | 70 | 0.00 |
| --fp16 --tf32 1 | 79.63 | 75 | 0.00 |
| --bf16 --tf32 0 | 75.85 | 66 | 2.41 |
| --bf16 --tf32 1 | 73.19 | 61 | 2.41 |

Both bf16 and tf32 show a much better performance here.

Benchmark 4: gpt2

| Variation | Train samples per second | Diff % | Train loss |
|:----------|-------------------------:|-------:|-----------:|
| --tf32 0 | 28.77 | 0 | 3.34 |
| --tf32 1 | 63.51 | 121 | 3.34 |
| --fp16 --tf32 0 | 69.60 | 142 | 3.34 |
| --fp16 --tf32 1 | 69.98 | 143 | 3.34 |
| --bf16 --tf32 0 | 70.37 | 145 | 3.34 |
| --bf16 --tf32 1 | 69.88 | 143 | 3.34 |

Conclusions:

*** The benchmark command line was:

CUDA_VISIBLE_DEVICES=0 python ./scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' examples/pytorch/language-modeling/run_clm.py --model_name_or_path gpt2 \
--dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
--logging_strategy no --save_strategy no --do_train --max_train_samples 5000 \
--per_device_train_batch_size 16 --num_train_epochs 1 --warmup_steps 8 \
--block_size 512 --report_to none ' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'|--fp16|--bf16' '--tf32 0|--tf32 1' --report-metric-keys train_loss \
--repeat-times 1 --base-variation '--tf32 0'

Benchmark 5: gpt2-medium

| Variation | Train samples per second | Diff % | Train loss |
|:----------|-------------------------:|-------:|-----------:|
| --tf32 0 | 10.60 | 0 | 2.98 |
| --tf32 1 | 24.81 | 134 | 2.98 |
| --fp16 --tf32 0 | 27.67 | 161 | 2.99 |
| --fp16 --tf32 1 | 27.62 | 160 | 2.99 |
| --bf16 --tf32 0 | 27.57 | 160 | 2.99 |
| --bf16 --tf32 1 | 27.55 | 160 | 2.99 |

Conclusions:


CUDA_VISIBLE_DEVICES=3 python ./scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' examples/pytorch/language-modeling/run_clm.py --model_name_or_path gpt2-medium \
--dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
--logging_strategy no --save_strategy no --do_train --max_train_samples 2500 \
--per_device_train_batch_size 8 --num_train_epochs 1 --warmup_steps 8 \
--block_size 512 --report_to none ' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'|--fp16|--bf16' '--tf32 0|--tf32 1' --report-metric-keys train_loss \
--repeat-times 1 --base-variation '--tf32 0'
stas00 commented 2 years ago

gradient accumulation steps

Let's choose the t5-base model to test with, as it's pretty large yet doesn't overflow like t5-large.
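
As a quick refresher on what is being varied here: gradient accumulation runs several micro-batches per optimizer step, so the work per weight update grows while activation memory stays that of a single micro-batch. A minimal sketch of the idea outside the Trainer (toy model and data, not the t5 setup):

import torch

accum_steps = 4  # --gradient_accumulation_steps 4
model = torch.nn.Linear(512, 512, device="cuda")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
batches = [torch.randn(8, 512, device="cuda") for _ in range(16)]  # stand-in dataloader

optimizer.zero_grad()
for step, batch in enumerate(batches, start=1):
    loss = model(batch).pow(2).mean()
    (loss / accum_steps).backward()   # scale so the accumulated grad matches one big batch
    if step % accum_steps == 0:       # effective batch size = 8 * accum_steps
        optimizer.step()
        optimizer.zero_grad()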

Let's measure --gradient_accumulation_steps 1,2,4,8,16,32 with different precision configurations.

| Variation | Train samples per second | Diff % | Train loss |
|:----------|-------------------------:|-------:|-----------:|
| --gradient_accumulation_steps 1 --tf32 0 | 93.68 | 0 | 2.21 |
| --gradient_accumulation_steps 1 --tf32 1 | 210.53 | 125 | 2.21 |
| --gradient_accumulation_steps 1 --tf32 0 --fp16 | 217.75 | 132 | 2.21 |
| --gradient_accumulation_steps 1 --tf32 0 --bf16 | 224.09 | 139 | 2.22 |
| --gradient_accumulation_steps 2 --tf32 0 | 97.48 | 4 | 2.28 |
| --gradient_accumulation_steps 2 --tf32 1 | 236.39 | 152 | 2.28 |
| --gradient_accumulation_steps 2 --tf32 0 --fp16 | 244.81 | 161 | 2.28 |
| --gradient_accumulation_steps 2 --tf32 0 --bf16 | 246.08 | 163 | 2.29 |
| --gradient_accumulation_steps 4 --tf32 0 | 99.68 | 6 | 2.39 |
| --gradient_accumulation_steps 4 --tf32 1 | 248.24 | 165 | 2.40 |
| --gradient_accumulation_steps 4 --tf32 0 --fp16 | 259.20 | 177 | 2.41 |
| --gradient_accumulation_steps 4 --tf32 0 --bf16 | 263.39 | 181 | 2.42 |
| --gradient_accumulation_steps 8 --tf32 0 | 100.67 | 7 | 2.58 |
| --gradient_accumulation_steps 8 --tf32 1 | 252.45 | 169 | 2.58 |
| --gradient_accumulation_steps 8 --tf32 0 --fp16 | 261.59 | 179 | 2.58 |
| --gradient_accumulation_steps 8 --tf32 0 --bf16 | 267.37 | 185 | 2.62 |
| --gradient_accumulation_steps 16 --tf32 0 | 100.97 | 8 | 2.83 |
| --gradient_accumulation_steps 16 --tf32 1 | 253.68 | 171 | 2.84 |
| --gradient_accumulation_steps 16 --tf32 0 --fp16 | 256.13 | 173 | 2.84 |
| --gradient_accumulation_steps 16 --tf32 0 --bf16 | 274.14 | 193 | 2.89 |

Let's filter out just one subset so that it's easier to compare the gradient accumulation differences alone; re-running with just bf16 enabled (--tf32 0 --bf16):

| Variation | Train samples per second | Diff % | Train loss |
|:----------|-------------------------:|-------:|-----------:|
| --gradient_accumulation_steps 1 | 228.41 | 0 | 2.22 |
| --gradient_accumulation_steps 2 | 248.77 | 9 | 2.29 |
| --gradient_accumulation_steps 4 | 263.54 | 15 | 2.42 |
| --gradient_accumulation_steps 8 | 270.12 | 18 | 2.62 |
| --gradient_accumulation_steps 16 | 271.99 | 19 | 2.89 |

Conclusions:

1.

CUDA_VISIBLE_DEVICES=3 python ./scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' examples/pytorch/translation/run_translation.py --model_name_or_path t5-base \
--output_dir output_dir --do_train --label_smoothing 0.1 --logging_strategy no \
--save_strategy no --per_device_train_batch_size 32 --max_source_length 512 \
--max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: " --warmup_steps 50 \
--max_train_samples 20000 --dataloader_num_workers 2 ' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'--gradient_accumulation_steps 1|--gradient_accumulation_steps 2|--gradient_accumulation_steps 4|--gradient_accumulation_steps 8|--gradient_accumulation_steps 16' \
'--tf32 0|--tf32 1|--tf32 0 --fp16|--tf32 0 --bf16' --report-metric-keys \
train_loss --repeat-times 1

2.

CUDA_VISIBLE_DEVICES=3 python ./scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' examples/pytorch/translation/run_translation.py --model_name_or_path t5-base \
--output_dir output_dir --do_train --label_smoothing 0.1 --logging_strategy no \
--save_strategy no --per_device_train_batch_size 32 --max_source_length 512 \
--max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: " --warmup_steps 50 \
--max_train_samples 20000 --dataloader_num_workers 2 --tf32 0 --bf16 ' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'--gradient_accumulation_steps 1|--gradient_accumulation_steps 2|--gradient_accumulation_steps 4|--gradient_accumulation_steps 8|--gradient_accumulation_steps 16' \
--report-metric-keys train_loss --repeat-times 1

stas00 commented 2 years ago

batch size

| Variation | Train samples per second | Diff % | Train loss |
|:----------|-------------------------:|-------:|-----------:|
| --per_device_train_batch_size 1 | 7.77 | 0 | 1.90 |
| --per_device_train_batch_size 2 | 15.51 | 100 | 2.01 |
| --per_device_train_batch_size 4 | 29.66 | 282 | 2.09 |
| --per_device_train_batch_size 8 | 61.16 | 687 | 2.16 |
| --per_device_train_batch_size 16 | 115.84 | 1392 | 2.25 |
| --per_device_train_batch_size 32 | 224.96 | 2797 | 2.38 |

Conclusions:


CUDA_VISIBLE_DEVICES=3 python ./scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' examples/pytorch/translation/run_translation.py --model_name_or_path t5-base \
--output_dir output_dir --do_train --label_smoothing 0.1 --logging_strategy no \
--save_strategy no --max_source_length 512 \
--max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: " --warmup_steps 50 \
--max_train_samples 5000 --dataloader_num_workers 2 --bf16' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'--per_device_train_batch_size 1|--per_device_train_batch_size 2|--per_device_train_batch_size 4|--per_device_train_batch_size 8|--per_device_train_batch_size 16|--per_device_train_batch_size 32' \
--report-metric-keys train_loss --repeat-times 1 
stas00 commented 2 years ago

gradient checkpointing

| Variation | Train samples per second | Diff % | Train loss |
|:----------|-------------------------:|-------:|-----------:|
| --gradient_checkpointing 0 | 225.67 | 24 | 2.30 |
| --gradient_checkpointing 1 | 182.68 | 0 | 2.30 |

Conclusions:

Let's look at memory:

| Variation | Train samples per second | Diff % | Train loss | Train mem gpu alloc delta | Train mem gpu peaked delta |
|:----------|-------------------------:|-------:|-----------:|--------------------------:|---------------------------:|
| --gradient_checkpointing 0 | 122.81 | 35 | 2.38 | 2739MB | 3229MB |
| --gradient_checkpointing 1 | 90.92 | 0 | 2.38 | 2697MB | 1155MB |

We can clearly see that peak GPU memory is ~2/3 lower with gradient checkpointing enabled.

note: I had to halve the batch size in the 2nd benchmark as I was getting OOM.
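
For reference, gradient checkpointing trades compute for memory: intermediate activations are not kept for backward but recomputed on the fly. The Trainer toggles this via --gradient_checkpointing; a minimal sketch of the underlying mechanism with torch.utils.checkpoint (toy module, not the t5 code path):

import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(
    torch.nn.Linear(512, 2048), torch.nn.ReLU(), torch.nn.Linear(2048, 512)
).cuda()
x = torch.randn(8, 512, device="cuda", requires_grad=True)

out = checkpoint(block, x)        # activations inside `block` are recomputed in backward
out.pow(2).mean().backward()

# HF models expose the same idea via model.gradient_checkpointing_enable()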

CUDA_VISIBLE_DEVICES=3 python ./scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' examples/pytorch/translation/run_translation.py --model_name_or_path t5-base \
--output_dir output_dir --do_train --label_smoothing 0.1 --logging_strategy no \
--save_strategy no --per_device_train_batch_size 32 --max_source_length 512 \
--max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: " --warmup_steps 50 \
--max_train_samples 10000 --dataloader_num_workers 2 --bf16' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'--gradient_checkpointing 0|--gradient_checkpointing 1' --report-metric-keys \
train_loss --repeat-times 1 

CUDA_VISIBLE_DEVICES=3 python ./scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' examples/pytorch/translation/run_translation.py --model_name_or_path t5-base \
--output_dir output_dir --do_train --label_smoothing 0.1 --logging_strategy no \
--save_strategy no --per_device_train_batch_size 32 --max_source_length 512 \
--max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: " --warmup_steps 50 \
--max_train_samples 5000 --dataloader_num_workers 2 --bf16 --skip_memory_metrics 0' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'--gradient_checkpointing 0|--gradient_checkpointing 1' --report-metric-keys \
'train_loss train_mem_gpu_alloc_delta train_mem_gpu_peaked_delta' \
--repeat-times 1
stas00 commented 2 years ago

optimizers

Let's do fp32 first:

| Variation | Train samples per second | Diff % | Train loss |
|:----------|-------------------------:|-------:|-----------:|
| --optim adamw_hf | 214.55 | 2 | 2.21 |
| --optim adamw_torch | 209.72 | 0 | 2.21 |
| --optim adafactor | 158.56 | -24 | 2.21 |
| --optim adamw_apex_fused | 227.96 | 9 | 2.21 |

Observations:

fp16:

| Variation | Train samples per second | Diff % | Train loss |
|:----------|-------------------------:|-------:|-----------:|
| --optim adamw_hf | 221.08 | 5 | 2.21 |
| --optim adamw_torch | 209.85 | 0 | 2.21 |
| --optim adafactor | 160.69 | -23 | 2.21 |
| --optim adamw_apex_fused | 231.71 | 10 | 2.21 |

bf16:

| Variation | Train samples per second | Diff % | Train loss |
|:----------|-------------------------:|-------:|-----------:|
| --optim adamw_hf | 221.28 | 4 | 2.22 |
| --optim adamw_torch | 212.83 | 0 | 2.22 |
| --optim adafactor | 164.21 | -23 | 2.22 |
| --optim adamw_apex_fused | 237.31 | 12 | 2.22 |

Observations:


# fp32
CUDA_VISIBLE_DEVICES=3 python scripts/benchmark/trainer-benchmark.py --base-cmd \
' \
examples/pytorch/translation/run_translation.py --model_name_or_path t5-base --output_dir output_dir \
--do_train --label_smoothing 0.1 --logging_strategy no --save_strategy no --per_device_train_batch_size 32 \
--max_source_length 512 --max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: "  --warmup_steps 50 \
--max_train_samples 20000 --dataloader_num_workers 2 \
' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'--optim adamw_hf|--optim adamw_torch|--optim adafactor|--optim adamw_apex_fused' \
--report-metric-keys train_loss --base-variation '--optim adamw_torch'

# fp16 - just add --fp16 to base-cmd

# bf16 - just add --bf16 to base-cmd
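
For reference, a rough sketch of what these --optim choices correspond to when constructed by hand (assuming NVIDIA apex is installed for the fused variant; adamw_hf is transformers' own AdamW implementation and is easiest to select via the Trainer itself):

import torch
from transformers.optimization import Adafactor

model = torch.nn.Linear(512, 512).cuda()

# --optim adamw_torch
opt_torch = torch.optim.AdamW(model.parameters(), lr=1e-4)

# --optim adafactor (fixed lr here; Adafactor can also run with relative step sizes)
opt_ada = Adafactor(model.parameters(), lr=1e-4, relative_step=False, scale_parameter=False)

# --optim adamw_apex_fused
try:
    from apex.optimizers import FusedAdam
    opt_fused = FusedAdam(model.parameters(), lr=1e-4)
except ImportError:
    opt_fused = None  # apex not installed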
stas00 commented 2 years ago

combining winning strategies

Now let's combine the winning strategies from each individual benchmark above and compare with the baseline:

| Variation | Train samples per second | Diff % | Train loss |
|:----------|-------------------------:|-------:|-----------:|
| --optim adamw_torch --gradient_accumulation_steps 1 --tf32 0 | 92.15 | 0 | 2.21 |
| --optim adamw_apex_fused --gradient_accumulation_steps 8 --tf32 --bf16 | 267.21 | 190 | 2.62 |

Getting an almost 3x improvement in speed!
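
To apply the same winning combination in your own Trainer script, the equivalent TrainingArguments would look roughly like this (other arguments omitted; a sketch, not the benchmark harness itself):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="output_dir",
    per_device_train_batch_size=32,
    gradient_accumulation_steps=8,
    optim="adamw_apex_fused",
    tf32=True,
    bf16=True,
)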

CUDA_VISIBLE_DEVICES=3 python \
../transformers-stas/scripts/benchmark/trainer-benchmark.py --base-cmd \
' \
examples/pytorch/translation/run_translation.py --model_name_or_path t5-base --output_dir output_dir \
--do_train --label_smoothing 0.1 --logging_strategy no --save_strategy no --per_device_train_batch_size 32 \
--max_source_length 512 --max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: "  --warmup_steps 50 \
--max_train_samples 20000 --dataloader_num_workers 2 \
' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'--optim adamw_torch --gradient_accumulation_steps 1 --tf32 0|--optim adamw_apex_fused --gradient_accumulation_steps 8 --tf32 --bf16' \
--report-metric-keys train_loss --base-variation '--optim adamw_torch'
stas00 commented 2 years ago

RTX-3090 vs A100

In all the benchmarks above I was using a bigger batch size and more samples than in the corresponding RTX-3090 benchmarks, since the A100 40GB card can handle more than the RTX-3090 24GB. But let's now compare the two using the same config, so the RTX-3090 will be fully loaded while the A100 will be only partially loaded.

Also, each card is running in a different machine, so there is a bit of hardware difference as well.

A100

| Variation | Train samples per second | Diff % | Train loss |
|:----------|-------------------------:|-------:|-----------:|
| --optim adamw_torch --gradient_accumulation_steps 1 --tf32 0 | 85.99 | 0 | 2.16 |
| --optim adamw_apex_fused --gradient_accumulation_steps 8 --tf32 --bf16 | 153.72 | 79 | 2.42 |

RTX-3090

| Variation | Train samples per second | Diff % | Train loss |
|:----------|-------------------------:|-------:|-----------:|
| --optim adamw_torch --gradient_accumulation_steps 1 --tf32 0 | 88.94 | 0 | 2.16 |
| --optim adamw_apex_fused --gradient_accumulation_steps 8 --tf32 --bf16 | 173.15 | 95 | 2.42 |

Observations:

Same software was used for both setups:

transformers: 4.16.0.dev0
torch       : 1.10.1
cuda        : 11.3
python      : 3.8.12
CUDA_VISIBLE_DEVICES=0 python \
/hf/transformers-trainer-benchmark/scripts/benchmark/trainer-benchmark.py \
--base-cmd \
' \
examples/pytorch/translation/run_translation.py --model_name_or_path t5-base --output_dir output_dir \
--do_train --label_smoothing 0.1 --logging_strategy no --save_strategy no --per_device_train_batch_size 16 \
--max_source_length 512 --max_target_length 512 --num_train_epochs 1 --overwrite_output_dir \
--source_lang en --target_lang ro --dataset_name wmt16 --dataset_config "ro-en" \
--source_prefix "translate English to Romanian: "  --warmup_steps 50 \
--max_train_samples 20000 --dataloader_num_workers 2 \
' \
--target-metric-key train_samples_per_second --repeat-times 1 --variations \
'--optim adamw_torch --gradient_accumulation_steps 1 --tf32 0|--optim adamw_apex_fused --gradient_accumulation_steps 8 --tf32 --bf16' \
--report-metric-keys train_loss

I thought that perhaps this had to do with bf16, so I re-did the same benchmark with --fp16 instead of --bf16, but the outcome is similar: the RTX-3090 still appears to be faster on the same workload:

A100

| Variation | Train samples per second | Diff % | Train loss |
|:----------|-------------------------:|-------:|-----------:|
| --optim adamw_torch --gradient_accumulation_steps 1 --tf32 0 | 85.89 | 0 | 2.16 |
| --optim adamw_apex_fused --gradient_accumulation_steps 8 --tf32 --fp16 | 144.20 | 68 | 2.39 |

RTX-3090

| Variation | Train samples per second | Diff % | Train loss |
|:----------|-------------------------:|-------:|-----------:|
| --optim adamw_torch --gradient_accumulation_steps 1 --tf32 0 | 88.95 | 0 | 2.16 |
| --optim adamw_apex_fused --gradient_accumulation_steps 8 --tf32 --fp16 | 168.28 | 89 | 2.39 |

Still not good for A100. Let's try w/o tf32:

A100

| Variation | Train samples per second | Diff % | Train loss |
|:----------|-------------------------:|-------:|-----------:|
| --optim adamw_torch --gradient_accumulation_steps 1 --tf32 0 | 86.35 | 0 | 2.16 |
| --optim adamw_apex_fused --gradient_accumulation_steps 8 --tf32 0 --fp16 | 156.16 | 81 | 2.39 |

RTX-3090

| Variation | Train samples per second | Diff % | Train loss |
|:----------|-------------------------:|-------:|-----------:|
| --optim adamw_torch --gradient_accumulation_steps 1 --tf32 0 | 88.87 | 0 | 2.16 |
| --optim adamw_apex_fused --gradient_accumulation_steps 8 --tf32 0 --fp16 | 167.36 | 88 | 2.39 |

This is better for A100. So tf32 was making things worse here for some reason.

Eddie Yan explained the reason for RTX-3090 being faster:

When there is underutilization on GPUs of similar architectures, it may come down to clock rates, and 3090 does have faster peak clock rates than A100.
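
A quick way to sanity-check the clock explanation on your own cards is to query them, e.g. (a small helper assuming nvidia-smi is on the PATH):

import subprocess

# report the peak SM clock per GPU (MHz); a 3090 typically boosts higher than an A100
out = subprocess.run(
    ["nvidia-smi", "--query-gpu=name,clocks.max.sm", "--format=csv,noheader"],
    capture_output=True, text=True, check=True,
)
print(out.stdout)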

eqy commented 2 years ago

@stas00 Another interesting issue that I found regarding batch size is that it is an important parameter when the model is mostly in fp32 and relies on autocast to dispatch to fp16 or bf16. I believe this is because the overhead of casting back and forth can dominate the total runtime compared to the actual kernel/operator.

Consider the following microbenchmark:

import time
import torch

from torch.cuda.amp import autocast

def bench(dtype, dim, shape, auto=True):
    # time 1000 forward passes of a single Linear layer, either under autocast
    # (auto=True, fp32 weights) or with the layer and input fully cast to `dtype`
    linear = torch.nn.Linear(shape[-1], dim, device='cuda')
    inp = torch.randn(shape, device='cuda')
    ctx_manager = None
    if not auto:
        inp = inp.to(dtype)
        linear = linear.to(dtype)
    else:
        ctx_manager = autocast(dtype=dtype)

    def run(inp, layer, ctx_manager):
        if ctx_manager is not None:
            with ctx_manager:
                layer(inp)
        else:
            layer(inp)

    # warmup once, then time the loop
    run(inp, linear, ctx_manager)
    torch.cuda.synchronize()
    t1 = time.time()
    for i in range(1000):
        run(inp, linear, ctx_manager)
    torch.cuda.synchronize()
    t2 = time.time()
    return t2 - t1

if __name__ == '__main__':
    hidden_dim = 1024
    for auto in (True, False):
        print(f"autocast: {auto}")
        for batch_size in (8, 16, 32, 64):
            shape = [batch_size, 56, hidden_dim]
            times = list()
            for dtype in (torch.float32, torch.float16, torch.bfloat16):
                times.append(bench(dtype, hidden_dim, shape, auto))
            print(f"batch_size: {batch_size} fp32 {times[0]:3f} fp16 {times[1]:3f} bf16 {times[2]:3f}")
            print(f"speedup: fp32 {(times[0]/times[0]):3f} fp16 {(times[0]/times[1]):3f} bf16 {(times[0]/times[2]):3f}")

I get the following times on A6000 (similar architecture to 3090):

autocast: True
batch_size: 8 fp32 0.040668 fp16 0.061515 bf16 0.060857
speedup: fp32 1.000000 fp16 0.661102 bf16 0.668253
batch_size: 16 fp32 0.050241 fp16 0.061965 bf16 0.061339
speedup: fp32 1.000000 fp16 0.810793 bf16 0.819065
batch_size: 32 fp32 0.109936 fp16 0.089657 bf16 0.091546
speedup: fp32 1.000000 fp16 1.226184 bf16 1.200876
batch_size: 64 fp32 0.189083 fp16 0.150391 bf16 0.150227
speedup: fp32 1.000000 fp16 1.257275 bf16 1.258648
autocast: False
batch_size: 8 fp32 0.038590 fp16 0.031647 bf16 0.030893
speedup: fp32 1.000000 fp16 1.219398 bf16 1.249145
batch_size: 16 fp32 0.049446 fp16 0.032320 bf16 0.031509
speedup: fp32 1.000000 fp16 1.529887 bf16 1.569281
batch_size: 32 fp32 0.111689 fp16 0.056600 bf16 0.060192
speedup: fp32 1.000000 fp16 1.973323 bf16 1.855555
batch_size: 64 fp32 0.190082 fp16 0.103095 bf16 0.104311
speedup: fp32 1.000000 fp16 1.843766 bf16 1.822261
stas00 commented 2 years ago

Thank you, @eqy.

update: edited out the original note on casting back, since the explicit casting is not being measured

Added a nicely formatted table output so it's much easier to analyze. Updated script attached: bench.txt
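
The attached bench.txt isn't reproduced here; the table/speedup formatting it adds is roughly of this shape (hypothetical helper, not the actual attachment):

def print_results(title, results, dtypes=("torch.float32", "torch.float16", "torch.bfloat16")):
    # results: {batch_size: [fp32_time, fp16_time, bf16_time]} as returned by bench()
    print(f"Autocast: {title}\nResults:")
    print("bs " + " ".join(dtypes))
    for bs, times in results.items():
        print(f"{bs:>2} " + " ".join(f"{t:.3f}" for t in times))
    print("Speedup:")
    print("bs " + " ".join(dtypes))
    for bs, times in results.items():
        print(f"{bs:>2} " + " ".join(f"{times[0] / t:.3f}" for t in times))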

On RTX-3090 I get:

Autocast: True

Results:

| bs | torch.float32 | torch.float16 | torch.bfloat16 |
|---:|--------------:|--------------:|---------------:|
| 8 | 0.057 | 0.070 | 0.070 |
| 16 | 0.082 | 0.082 | 0.070 |
| 32 | 0.169 | 0.103 | 0.119 |
| 64 | 0.267 | 0.190 | 0.191 |

Speedup:

| bs | torch.float32 | torch.float16 | torch.bfloat16 |
|---:|--------------:|--------------:|---------------:|
| 8 | 1.000 | 0.810 | 0.818 |
| 16 | 1.000 | 0.997 | 1.179 |
| 32 | 1.000 | 1.639 | 1.421 |
| 64 | 1.000 | 1.411 | 1.398 |

Autocast: False

Results:

| bs | torch.float32 | torch.float16 | torch.bfloat16 |
|---:|--------------:|--------------:|---------------:|
| 8 | 0.052 | 0.040 | 0.040 |
| 16 | 0.082 | 0.045 | 0.045 |
| 32 | 0.170 | 0.073 | 0.090 |
| 64 | 0.268 | 0.143 | 0.148 |

Speedup:

| bs | torch.float32 | torch.float16 | torch.bfloat16 |
|---:|--------------:|--------------:|---------------:|
| 8 | 1.000 | 1.320 | 1.306 |
| 16 | 1.000 | 1.849 | 1.820 |
| 32 | 1.000 | 2.338 | 1.895 |
| 64 | 1.000 | 1.866 | 1.814 |
eqy commented 2 years ago

I believe there are "speed-of-light" cases where the cast-back wouldn't be necessary, though this may not be possible for the architectures we're interested in. Here, I think the big picture is that once the batch-size falls below a certain amount, the "building-block" operations like GEMMs will be slower in reduced precision vs. fp32 when casts are needed.

stas00 commented 2 years ago

Why do you think bs=32 is an oddball relative to the other batch sizes for speedup? In both cases, w/ and w/o amp, it's relatively faster for bf16 and fp16 than bs=64, and much more significantly so for fp16. One would expect 8 < 16 < 32 < 64, but here it is 8 < 16 < 64 < 32.

So the actual results are proportionally in line, but the speedups aren't.

eqy commented 2 years ago

That's interesting, I didn't see quite so dramatic results on an A100 (80GB), 2 runs:

Autocast: True

Results:

| bs | torch.float32 | torch.float16 | torch.bfloat16 |
|---:|--------------:|--------------:|---------------:|
| 8 | 0.062 | 0.086 | 0.077 |
| 16 | 0.043 | 0.089 | 0.085 |
| 32 | 0.073 | 0.084 | 0.084 |
| 64 | 0.119 | 0.101 | 0.112 |

Speedup:

| bs | torch.float32 | torch.float16 | torch.bfloat16 |
|---:|--------------:|--------------:|---------------:|
| 8 | 1.000 | 0.714 | 0.805 |
| 16 | 1.000 | 0.486 | 0.508 |
| 32 | 1.000 | 0.865 | 0.871 |
| 64 | 1.000 | 1.181 | 1.058 |

Autocast: False

Results:

| bs | torch.float32 | torch.float16 | torch.bfloat16 |
|---:|--------------:|--------------:|---------------:|
| 8 | 0.045 | 0.049 | 0.040 |
| 16 | 0.041 | 0.048 | 0.047 |
| 32 | 0.073 | 0.046 | 0.044 |
| 64 | 0.120 | 0.063 | 0.076 |

Speedup:

| bs | torch.float32 | torch.float16 | torch.bfloat16 |
|---:|--------------:|--------------:|---------------:|
| 8 | 1.000 | 0.908 | 1.129 |
| 16 | 1.000 | 0.855 | 0.873 |
| 32 | 1.000 | 1.570 | 1.638 |
| 64 | 1.000 | 1.913 | 1.580 |

Autocast: True

Results:

| bs | torch.float32 | torch.float16 | torch.bfloat16 |
|---:|--------------:|--------------:|---------------:|
| 8 | 0.062 | 0.086 | 0.077 |
| 16 | 0.059 | 0.089 | 0.085 |
| 32 | 0.073 | 0.084 | 0.084 |
| 64 | 0.119 | 0.101 | 0.114 |

Speedup:

| bs | torch.float32 | torch.float16 | torch.bfloat16 |
|---:|--------------:|--------------:|---------------:|
| 8 | 1.000 | 0.720 | 0.802 |
| 16 | 1.000 | 0.660 | 0.691 |
| 32 | 1.000 | 0.871 | 0.866 |
| 64 | 1.000 | 1.182 | 1.051 |

Autocast: False

Results:

| bs | torch.float32 | torch.float16 | torch.bfloat16 |
|---:|--------------:|--------------:|---------------:|
| 8 | 0.044 | 0.048 | 0.041 |
| 16 | 0.041 | 0.047 | 0.048 |
| 32 | 0.073 | 0.045 | 0.047 |
| 64 | 0.120 | 0.063 | 0.077 |

Speedup:

| bs | torch.float32 | torch.float16 | torch.bfloat16 |
|---:|--------------:|--------------:|---------------:|
| 8 | 1.000 | 0.929 | 1.094 |
| 16 | 1.000 | 0.883 | 0.861 |
| 32 | 1.000 | 1.615 | 1.541 |
| 64 | 1.000 | 1.907 | 1.562 |
pratikchhapolika commented 2 years ago

🖥 Benchmarking transformers w/ HF Trainer on A100 40GB

We are going to use a special benchmarking tool that will do all the work for us. #14934

This is the index post and specific benchmarks are in their own posts below:

  1. fp16 vs bf16 vs tf32 vs fp32
  2. gradient accumulation steps
  3. batch size
  4. gradient checkpointing
  5. optimizers
  6. combining winning strategies ~3x speed improvement!
  7. RTX-3090 vs A100

Note that each benchmark was run only once, so multiple runs and averaging is probably going to give slightly different results. The purpose here though is to see relative differences roughly and not try to give an exact number.

See also the same benchmarks for RTX-3090

Is all benchmarking done on A100 ("NVIDIA_TESLA_A100") single GPU? Can you also include CUDA memory required Vs Data points for training Vs No. of GPU's.

stas00 commented 2 years ago

Is all benchmarking done on A100 ("NVIDIA_TESLA_A100") single GPU?

Yes.

Can you also include CUDA memory required Vs Data points for training Vs No. of GPU's.

I don't understand your question.

pratikchhapolika commented 2 years ago

Is all benchmarking done on A100 ("NVIDIA_TESLA_A100") single GPU?

Yes.

Can you also include CUDA memory required Vs Data points for training Vs No. of GPU's.

I don't understand your question.

On how many data points and epochs is it benchmarked on with Single GPU?

pratikchhapolika commented 2 years ago

Is all benchmarking done on A100 ("NVIDIA_TESLA_A100") single GPU?

Yes.

Can you also include CUDA memory required Vs Data points for training Vs No. of GPU's.

I don't understand your question.

On how many data points and epochs is it benchmarked on with Single GPU?

I get error with 4 GPU's, 20 epochs on A100 with 700000 data points

python -m torch.distributed.launch --nproc_per_node 4 train.py --gradient_accumulation_steps 8 --per_device_train_batch_size 8 --optim adamw_hf --tf32 --bf16

Traceback (most recent call last):
  File "train.py", line 160, in <module>
    trainer.train()
  File "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py", line 1398, in train
    tr_loss_step = self.training_step(model, inputs)
  File "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py", line 1994, in training_step
    self.scaler.scale(loss).backward()
  File "/opt/conda/lib/python3.7/site-packages/torch/_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/opt/conda/lib/python3.7/site-packages/torch/autograd/__init__.py", line 175, in backward
    allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
RuntimeError: CUDA out of memory. Tried to allocate 3.82 GiB (GPU 0; 39.59 GiB total capacity; 17.07 GiB already allocated; 2.75 GiB free; 21.43 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

stas00 commented 2 years ago

On how many data points and epochs is it benchmarked on with Single GPU?

It's defined by --max_train_samples

To your last OOM comment - please let's not derail this Benchmark Issue. If you want to discuss an unrelated question please open a new issue. Best to delete it from here and post in another Issue. Thank you.