intel / torch-xpu-ops

Apache License 2.0

batch_normalization: Introduce vectorization optimization in the batch norm elementwise kernel. #933

Closed xytintel closed 1 week ago

xytintel commented 2 months ago

Because the group-stride loop implementation performs poorly for low-precision data types on PVC (JIRA: PYTORCHDGQ-5162), this PR applies partial vectorization to the batch norm elementwise kernel.
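For context, the elementwise stage being optimized here computes the standard batch norm transform per channel. A minimal PyTorch sketch of that formula (CPU, float32, using running statistics for clarity; this is a reference for the math, not the kernel itself):

```python
import torch

# Batch norm's elementwise stage applies, per channel c:
#   y = (x - mean[c]) / sqrt(var[c] + eps) * weight[c] + bias[c]
x = torch.randn(2, 3, 4, 4)
bn = torch.nn.BatchNorm2d(3)
bn.eval()  # eval mode uses running stats, so the formula below is exact

mean = bn.running_mean.view(1, -1, 1, 1)
var = bn.running_var.view(1, -1, 1, 1)
weight = bn.weight.view(1, -1, 1, 1)
bias = bn.bias.view(1, -1, 1, 1)
y_ref = (x - mean) / torch.sqrt(var + bn.eps) * weight + bias

# Matches the module's own output
assert torch.allclose(bn(x), y_ref, atol=1e-6)
```

The vectorized kernel applies exactly this transform, but loads and stores several contiguous elements per work-item instead of one.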

EikanWang commented 2 weeks ago

@xytintel , may I know if this PR is still available?

xytintel commented 2 weeks ago

> @xytintel , may I know if this PR is still available?

Yes, it is in the process of being merged.

xytintel commented 1 week ago

Group stride kernel: 12.493 ms
Vectorized kernel: 10.024 ms
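From those timings, the vectorized kernel gives roughly a 1.25x speedup, i.e. close to a 20% reduction in kernel time; a quick check:

```python
# Timings reported above, in milliseconds.
group_stride_ms = 12.493
vectorized_ms = 10.024

speedup = group_stride_ms / vectorized_ms
reduction = 1.0 - vectorized_ms / group_stride_ms
print(f"speedup: {speedup:.2f}x, time reduction: {reduction:.1%}")
# prints: speedup: 1.25x, time reduction: 19.8%
```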

```python
import torch

# Benchmark BatchNorm2d in bfloat16 on XPU and profile the kernels.
a = torch.rand(2, 256, 128, 128).xpu().bfloat16()
bn = torch.nn.BatchNorm2d(256).xpu().bfloat16()

prof_xpu = torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.XPU,
    ],
)
with prof_xpu:
    for i in range(120):
        output = bn(a)

# Aggregate per-kernel stats, sorted by time spent on the XPU.
print(prof_xpu.key_averages().table(sort_by="self_xpu_time_total", row_limit=100000))
print(output.dtype)  # should remain torch.bfloat16
```