Open rnwang04 opened 1 year ago
are you able to get the similar result with a standalone bmm op?
Yes, I can get a similar result with a standalone bmm op. I found that if I just loop bmm with the same input shapes, then fp16 is much faster than fp32. However, if each iteration of the loop uses a different input shape, then fp16 becomes much slower.
import torch
import os
import time
import intel_extension_for_pytorch as ipex

st1 = time.time()
for i in range(32):
    a1 = torch.rand(32, 1, i).to('xpu')
    b1 = torch.rand(32, i, 100).to('xpu')
    o = torch.bmm(a1, b1)
    a2 = torch.rand(32, 1, 100).to('xpu')
    b2 = torch.rand(32, 100, i).to('xpu')
    o = torch.bmm(a2, b2)
st2 = time.time()
print(f"bmm fp32 cost : {(st2-st1) * 1000}ms")

st1 = time.time()
for i in range(1000):
    a1 = torch.rand(32, 1, i).half().to('xpu')
    b1 = torch.rand(32, i, 100).half().to('xpu')
    o = torch.bmm(a1, b1)
    a2 = torch.rand(32, 1, 100).half().to('xpu')
    b2 = torch.rand(32, 100, i).half().to('xpu')
    o = torch.bmm(a2, b2)
st2 = time.time()
print(f"bmm fp16 cost : {(st2-st1) * 1000}ms")
a132_list = [torch.rand(32, 1, i).to('xpu') for i in range(32)]
a232_list = [torch.rand(32, 1, 100).to('xpu') for _ in range(32)]
b132_list = [torch.rand(32, i, 100).to('xpu') for i in range(32)]
b232_list = [torch.rand(32, 100, i).to('xpu') for i in range(32)]
st1 = time.time()
for i in range(32):
    o = torch.bmm(a132_list[i], b132_list[i])
    o = torch.bmm(a232_list[i], b232_list[i])
st2 = time.time()
print(f"bmm fp32 cost : {(st2-st1) * 1000}ms")

a116_list = [torch.rand(32, 1, i).half().to('xpu') for i in range(32)]
a216_list = [torch.rand(32, 1, 100).half().to('xpu') for _ in range(32)]
b116_list = [torch.rand(32, i, 100).half().to('xpu') for i in range(32)]
b216_list = [torch.rand(32, 100, i).half().to('xpu') for i in range(32)]
st1 = time.time()
for i in range(32):
    o = torch.bmm(a116_list[i], b116_list[i])
    o = torch.bmm(a216_list[i], b216_list[i])
st2 = time.time()
print(f"bmm fp16 cost : {(st2-st1) * 1000}ms")
output is:
bmm fp32 cost : 146.5318202972412ms
bmm fp16 cost : 12086.34614944458ms
bmm fp32 cost : 15.504598617553711ms
bmm fp16 cost : 88.47284317016602ms
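For reference, the fixed-shape case I mentioned (fp16 faster than fp32 when every iteration reuses the same input) can be checked with a minimal sketch like the one below; the shape (32, 1, 16) x (32, 16, 100) is only illustrative, and it reuses the same imports as the script above.

def time_fixed_shape(dtype):
    # fixed, illustrative shape reused on every iteration
    a = torch.rand(32, 1, 16).to(dtype).to('xpu')
    b = torch.rand(32, 16, 100).to(dtype).to('xpu')
    o = torch.bmm(a, b)  # warmup for this shape/dtype
    torch.xpu.synchronize()
    st = time.time()
    for _ in range(32):
        o = torch.bmm(a, b)
    torch.xpu.synchronize()
    return (time.time() - st) * 1000

print(f"fixed-shape bmm fp32 cost : {time_fixed_shape(torch.float32)}ms")
print(f"fixed-shape bmm fp16 cost : {time_fixed_shape(torch.float16)}ms")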
@rnwang04 Thanks for the benchmarking. I have two comments:
1. Please add a warmup before the timed loop so that first-run overhead is not counted.
2. The measured time also includes the torch.rand tensor creation and host-to-device copy time. Can you only count the bmm time in your benchmark?

@jgong5 Thanks for the reply! I have updated the test script based on your comments: I now exclude the rand time and added a warmup. Now the time of aten::bmm is almost the same for fp16 & fp32. This is my new script:
import torch
import os
import time
import intel_extension_for_pytorch as ipex
a132_list = [torch.rand(32, 1, i).to('xpu') for i in range(32)]
a232_list = [torch.rand(32, 1, 100).to('xpu') for _ in range(32)]
b132_list = [torch.rand(32, i, 100).to('xpu') for i in range(32)]
b232_list = [torch.rand(32, 100, i).to('xpu') for i in range(32)]
# warmup for fp32
for i in range(32):
    o = torch.bmm(a132_list[i], b132_list[i])
    o = torch.bmm(a232_list[i], b232_list[i])
torch.xpu.synchronize()

st1 = time.time()
for i in range(32):
    o = torch.bmm(a132_list[i], b132_list[i])
    o = torch.bmm(a232_list[i], b232_list[i])
torch.xpu.synchronize()
st2 = time.time()
print(f"bmm fp32 cost : {(st2-st1) * 1000}ms")

a116_list = [torch.rand(32, 1, i).half().to('xpu') for i in range(32)]
a216_list = [torch.rand(32, 1, 100).half().to('xpu') for _ in range(32)]
b116_list = [torch.rand(32, i, 100).half().to('xpu') for i in range(32)]
b216_list = [torch.rand(32, 100, i).half().to('xpu') for i in range(32)]

# warmup for fp16
for i in range(32):
    o = torch.bmm(a116_list[i], b116_list[i])
    o = torch.bmm(a216_list[i], b216_list[i])
torch.xpu.synchronize()

st1 = time.time()
for i in range(32):
    o = torch.bmm(a116_list[i], b116_list[i])
    o = torch.bmm(a216_list[i], b216_list[i])
torch.xpu.synchronize()
st2 = time.time()
print(f"bmm fp16 cost : {(st2-st1) * 1000}ms")
and this is my output:
bmm fp32 cost : 1.3713836669921875ms
bmm fp16 cost : 1.3885498046875ms
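If it helps, a rough per-shape timing along the same lines (same imports and synchronize-based timing as above; this is a sketch, not a precise profiler measurement) can show whether any particular shape dominates the fp16 time:

# rough per-shape sketch: time each (32, 1, i) x (32, i, 100) bmm separately
for i in range(32):
    a = torch.rand(32, 1, i).half().to('xpu')
    b = torch.rand(32, i, 100).half().to('xpu')
    o = torch.bmm(a, b)  # per-shape warmup
    torch.xpu.synchronize()
    st = time.time()
    o = torch.bmm(a, b)
    torch.xpu.synchronize()
    print(f"i={i}: {(time.time() - st) * 1000:.3f}ms")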
Now my question is why this op behaves differently in the unit test and in the LLM. I think in the above llama-3b profile I had already done the warmup, but as you can see, aten::bmm is still much slower in fp16.
I guess that is due to the dynamic shapes? A quick test could be to feed the model with all the sequence lengths during warmup. I don't mean that is the best practice, but I just wanted to confirm whether that is the problem.
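At the bmm level, the idea is just something like the sketch below; max_seq_len is a placeholder for the range of sequence lengths your model actually sees during inference.

# warm up every (32, 1, i) x (32, i, 100) shape once before measuring,
# so any per-shape first-run overhead happens outside the timed run
max_seq_len = 32  # placeholder value
for i in range(1, max_seq_len + 1):
    a = torch.rand(32, 1, i).half().to('xpu')
    b = torch.rand(32, i, 100).half().to('xpu')
    torch.bmm(a, b)
torch.xpu.synchronize()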
Describe the bug
Hi all, when I do some profiling with open-llama-3b on Arc A770, I found that in float16, aten::bmm becomes extremely slow compared to float32 (111.4ms vs 22.5ms). I wonder whether this behavior is normal, or whether this is an issue that needs to be fixed?
reproduce code:
fp32 profile result:
fp16 profile result:
Other ops (mul / add / ...) become slower in fp16 too, but aten::bmm is the most obvious.
Versions