NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Could torch.einsum gain a speed boost? #394

Closed: fyubang closed this issue 5 years ago

fyubang commented 5 years ago

I am trying to fine-tune XLNet and found that memory usage was halved, but training was slower than FP32 (even when I doubled the batch size).

Environment: V100, CUDA 10.0, PyTorch 1.1

The environment itself is fine, because I tried BERT + FP16 and it was much faster than FP32. I suspect the problem is torch.einsum, but I am not sure.

ptrblck commented 5 years ago

Hi @fyubang,

could you post a link to the repo you are using so that we can have a look?

fyubang commented 5 years ago

> Hi @fyubang,
>
> could you post a link to the repo you are using so that we can have a look?

Sorry, I forgot the link. I used the code here: https://github.com/huggingface/pytorch-transformers/blob/master/examples/run_squad.py

fyubang commented 5 years ago

Hi @ptrblck, I tried the new Hugging Face repo as well; it did not work either. https://github.com/huggingface/pytorch-transformers/blob/master/examples/run_squad.py

ptrblck commented 5 years ago

Thanks for the link, @fyubang. We'll take a look at it.

ptrblck commented 5 years ago

We tried to compare the performance of an FP32 run and an amp run using opt_level='O1'. For this, we cloned the current repo from @huggingface and used the command as given there for the FP32 run:

python -m torch.distributed.launch --nproc_per_node=8 ./examples/run_squad.py \
    --model_type bert \
    --model_name_or_path bert-large-uncased-whole-word-masking \
    --do_train \
    --do_eval \
    --do_lower_case \
    --train_file $SQUAD_DIR/train-v1.1.json \
    --predict_file $SQUAD_DIR/dev-v1.1.json \
    --learning_rate 3e-5 \
    --num_train_epochs 2 \
    --max_seq_length 384 \
    --doc_stride 128 \
    --output_dir ../models/wwm_uncased_finetuned_squad/ \
    --per_gpu_eval_batch_size=3   \
    --per_gpu_train_batch_size=3 

Using 8 V100 GPUs (each with 32GB), we could achieve a mean speed of ~2.65 iterations/second.

However, when passing the --fp16 argument to the same command, apex raises an error, since DDP is initialized before amp.initialize is called. Did you observe the same error?

After changing the order of initialization, we could successfully run the script on the same machine achieving ~3.70 iterations/second, which seems reasonable.
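For reference, a minimal sketch of the corrected ordering (the model, optimizer, and device id below are toy placeholders, and the process group is assumed to be initialized already):

import torch
from apex import amp

# toy model/optimizer just to illustrate the call order; assumes
# torch.distributed.init_process_group() has already been called
model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=3e-5)

# amp.initialize must wrap the raw model and optimizer first ...
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
# ... and only afterwards is the model wrapped in DistributedDataParallel
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[0])  # device id is a placeholder for local_rank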

By "it did not work either", are you referring to the raised error or to a slower run using amp?

CC @huggingface Is this a known issue and would you be interested in a fix?

ptrblck commented 5 years ago

I reran the test using XLNet:

python -m torch.distributed.launch --nproc_per_node=8 ./examples/run_squad.py \
    --model_type xlnet \
    --model_name_or_path xlnet-large-cased \
    --do_train \
    --do_eval \
    --do_lower_case \
    --train_file $SQUAD_DIR/train-v1.1.json \
    --predict_file $SQUAD_DIR/dev-v1.1.json \
    --learning_rate 3e-5 \
    --num_train_epochs 2 \
    --max_seq_length 384 \
    --doc_stride 128 \
    --output_dir ../models/wwm_uncased_finetuned_squad/ \
    --per_gpu_eval_batch_size=3   \
    --per_gpu_train_batch_size=3 

and got the following numbers:

FP32: ~1.35 iterations/second
AMP O1: ~1.44 iterations/second

The performance benefit is indeed smaller and worth a closer look.

fyubang commented 5 years ago

@ptrblck Thanks for your reply. I got similar results to yours. I think the reason is the heavy use of torch.einsum in the model, e.g.: torch.einsum('ibnd,jbnd->ijbn', a, b)

I tried to replace it by:

# a: (i, b, n, d), b: (j, b, n, d)
a_tmp = a.permute(1, 2, 0, 3)  # (b, n, i, d)
b_tmp = b.permute(1, 2, 3, 0)  # (b, n, d, j)
res = a_tmp.matmul(b_tmp)      # (b, n, i, j)
res = res.permute(2, 3, 0, 1)  # (i, j, b, n)

but it became even slower than torch.einsum.

fyubang commented 5 years ago

@ptrblck I tested the speed of matmul when the input shapes are (a, b, c, d) and (a, b, d, e). I found that FP16 is much slower than FP32 (roughly 1:20). That may be why FP16 was slower overall.
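A minimal sketch of the kind of comparison described above (the shapes below are placeholders, not the ones actually used):

import time
import torch

def time_batched_matmul(dtype, n_iters=1000):
    # 4-D batched matmul: (a, b, c, d) @ (a, b, d, e) -> (a, b, c, e)
    x = torch.randn(24, 16, 384, 64, device='cuda', dtype=dtype)
    y = torch.randn(24, 16, 64, 384, device='cuda', dtype=dtype)
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(n_iters):
        torch.matmul(x, y)
    torch.cuda.synchronize()
    return (time.time() - t0) / n_iters * 1e6  # microseconds per iteration

print('fp32: {:.3f}us per iteration'.format(time_batched_matmul(torch.float32)))
print('fp16: {:.3f}us per iteration'.format(time_batched_matmul(torch.float16)))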

ptrblck commented 5 years ago

@fyubang Note that the shapes for GEMMs should be multiples of 8 as explained in our pinned topic.

Here is a small benchmark using 1) shapes that are multiples of 8 and 2) shapes that slightly miss this condition:

# 1) shapes that are multiples of 8
import time
import torch

I, J, K = 64, 1024, 1024
A = torch.randn(I, J, device='cuda', dtype=torch.half)
B = torch.randn(J, K, device='cuda', dtype=torch.half)

nb_iters = 1000
torch.cuda.synchronize()
t0 = time.time()
for _ in range(nb_iters):
    C = torch.matmul(A, B)
torch.cuda.synchronize()
t1 = time.time()
print('{:.3f}us per iteration'.format((t1 - t0) / nb_iters * 1e6))

> 16.043us per iteration

# 2) shapes that slightly miss the multiples-of-8 condition
I, J, K = 63, 1023, 1023
A = torch.randn(I, J, device='cuda', dtype=torch.half)
B = torch.randn(J, K, device='cuda', dtype=torch.half)

nb_iters = 1000
torch.cuda.synchronize()
t0 = time.time()
for _ in range(nb_iters):
    C = torch.matmul(A, B)
torch.cuda.synchronize()
t1 = time.time()
print('{:.3f}us per iteration'.format((t1 - t0) / nb_iters * 1e6))

> 39.476us per iteration

Could this also be the reason for the minor speedup in XLNet?

fyubang commented 5 years ago

@ptrblck Thanks for your reply. In fact, when I tried (i, j) matmul (j, k), I always got a speed boost. However, when I tried a batched (a, b, c) matmul (a, c, d), it did not get accelerated. In addition, here is the XLNet config:

{
  "attn_type": "bi",
  "bi_data": false,
  "clamp_len": -1,
  "d_head": 64,
  "d_inner": 4096,
  "d_model": 1024,
  "dropatt": 0.1,
  "dropout": 0.1,
  "ff_activation": "gelu",
  "init": "normal",
  "init_range": 0.1,
  "init_std": 0.02,
  "initializer_range": 0.02,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "mem_len": null,
  "n_head": 16,
  "n_layer": 24,
  "n_token": 32000,
  "reuse_len": null,
  "same_length": false,
  "untie_r": true
}

Since these are all multiples of 8, I don't think the "multiples of 8" condition is the problem.

ptrblck commented 5 years ago

By "when I tried (i, j) matmul (j, k), I always got a speed boost", do you mean that every FP16 matmul of this form is faster than the corresponding FP32 matmul regardless of the input shapes? This sounds strange to me, as I get FP16 timings similar to FP32 for non-x8 shapes, while x8-shaped FP16 matmuls yield a speedup. I also added another dimension and still get a speedup for x8-shaped FP16.

Could you try adding some warmup iterations before the actual timings (see the sketch below)? The first measured iterations can be biased by one-time kernel selection overhead.
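A minimal warmup pattern might look like this (the shapes here are arbitrary placeholders):

import time
import torch

a = torch.randn(40, 80, 24, device='cuda', dtype=torch.half)
b = torch.randn(40, 24, 16, device='cuda', dtype=torch.half)

# warmup: run a few iterations first so kernel selection and caching
# do not distort the measured time
for _ in range(50):
    torch.matmul(a, b)
torch.cuda.synchronize()

t0 = time.time()
for _ in range(1000):
    torch.matmul(a, b)
torch.cuda.synchronize()
print('{:.3f}us per iteration'.format((time.time() - t0) / 1000 * 1e6))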

Thanks for the information about XLNet. We'll look into it.

fyubang commented 5 years ago

@ptrblck Thanks for your reply. For the first question: yes, that is what I meant, but I may have used shapes like 20 or 60 rather than 63.

For the second question: could you try this code (all shapes are multiples of 8) and check whether you still see a speedup for FP16?

import os
os.environ['CUDA_VISIBLE_DEVICES']='0'
import torch
from time import time
# 1) fp32
a = torch.empty(24,32,40,48, dtype=torch.float32).to('cuda')
b = torch.empty(64,32,40,48, dtype=torch.float32).to('cuda')
c = torch.empty(40,80,24, dtype=torch.float32).to('cuda')
d = torch.empty(40,24,16, dtype=torch.float32).to('cuda')

torch.cuda.synchronize()
st = time()
for _ in range(1000):
    c.matmul(d)
torch.cuda.synchronize()
print(time()-st)

torch.cuda.synchronize()
st = time()
for _ in range(1000):
    torch.einsum('ibnd,jbnd->ijbn', a, b)
torch.cuda.synchronize()
print(time()-st)

# 2) fp16
a = torch.empty(24,32,40,48, dtype=torch.float16).to('cuda')
b = torch.empty(64,32,40,48, dtype=torch.float16).to('cuda')
c = torch.empty(40,80,24, dtype=torch.float16).to('cuda')
d = torch.empty(40,24,16, dtype=torch.float16).to('cuda')

torch.cuda.synchronize()
st = time()
for _ in range(1000):
    torch.matmul(c,d)
torch.cuda.synchronize()
print(time()-st)

torch.cuda.synchronize()
st = time()
for _ in range(1000):
    torch.einsum('ibnd,jbnd->ijbn', a, b)
torch.cuda.synchronize()
print(time()-st)

My results are:

0.028162240982055664  # FP32 matmul
0.10057997703552246   # FP32 einsum
0.38828039169311523   # FP16 matmul
11.749611377716064    # FP16 einsum

ptrblck commented 5 years ago

Here are my results for your code on a TITAN V:

0.017162799835205078   # FP32 matmul
0.09859037399291992    # FP32 einsum
0.015858173370361328   # FP16 matmul
0.042925119400024414   # FP16 einsum

fyubang commented 5 years ago

@ptrblck Thanks for your reply. It seems torch.einsum does get a speedup after all. I will double-check.

ngimel commented 5 years ago

Closed in favor of https://github.com/pytorch/pytorch/issues/23061; this does not seem to be amp-specific.