xytintel closed this 1 week ago.
@xytintel, may I know whether this PR is still active?
Yes, it is in the process of being merged.
Group stride kernel: 12.493 ms
Vectorized kernel: 10.024 ms
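As a quick arithmetic check, the relative gain implied by these two timings can be computed directly. The variable names are illustrative only.

```python
# Kernel timings reported above, in milliseconds
group_stride_ms = 12.493
vectorized_ms = 10.024

# How many times faster the vectorized kernel is
speedup = group_stride_ms / vectorized_ms
# Fraction of the group-stride time that was saved, as a percentage
reduction_pct = (group_stride_ms - vectorized_ms) / group_stride_ms * 100

print(f"speedup: {speedup:.2f}x, time reduction: {reduction_pct:.1f}%")
# speedup: 1.25x, time reduction: 19.8%
```

In other words, the vectorized kernel takes roughly a fifth less time than the group-stride kernel for this workload.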
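```python
import torch

# bf16 input of shape (N, C, H, W) = (2, 256, 128, 128) on the XPU device
a = torch.rand(2, 256, 128, 128).xpu().bfloat16()
# BatchNorm2d over 256 channels, moved to XPU and cast to bf16
bn = torch.nn.BatchNorm2d(256).xpu().bfloat16()

# Profile both host (CPU) and device (XPU) activity
prof_xpu = torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.XPU,
    ],
)
with prof_xpu:
    # Run the forward pass 120 times so kernel timings accumulate
    for i in range(120):
        output = bn(a)

# Per-kernel summary, sorted by self XPU time
print(prof_xpu.key_averages().table(sort_by="self_xpu_time_total", row_limit=100000))
print(output.dtype)  # dtype of the BatchNorm output
```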
Because the group-stride loop implementation performs poorly for low-precision data types on PVC (Jira: PYTORCHDGQ-5162), a partial vectorization optimization is used instead.
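To illustrate the difference in access pattern, here is a conceptual Python sketch. It assumes that a "group-stride loop" means the usual pattern in which each work-item handles one element per iteration and advances by the total number of work-items, and that the vectorized variant handles `vec` contiguous elements per iteration, with a scalar tail for the remainder. The function names and the tail strategy are hypothetical. The thread does not say which parts of the kernel are vectorized, and this models only which indices each work-item visits, not the real SYCL kernel. With bf16 (2 bytes per element), `vec = 8` corresponds to 16-byte accesses instead of 2-byte ones, which is one common reason vectorization helps low-precision types.

```python
# Conceptual model only: which element indices each work-item visits.
# Names and the scalar-tail strategy are hypothetical, not the PR's kernel.

def group_stride_indices(item_id, num_items, n):
    # One scalar element per iteration; step = total number of work-items.
    return list(range(item_id, n, num_items))


def vectorized_indices(item_id, num_items, n, vec):
    # `vec` contiguous elements per iteration (one wide load/store),
    # step = num_items * vec over the vec-aligned prefix.
    main = n - n % vec
    idx = []
    for base in range(item_id * vec, main, num_items * vec):
        idx.extend(range(base, base + vec))
    # Leftover elements (fewer than vec) fall back to a scalar group-stride loop.
    idx.extend(range(main + item_id, n, num_items))
    return idx


def covers_each_once(indices_per_item, n):
    # Every element in [0, n) must be processed exactly once.
    flat = sorted(i for items in indices_per_item for i in items)
    return flat == list(range(n))


if __name__ == "__main__":
    n, num_items, vec = 37, 4, 8
    gs = [group_stride_indices(t, num_items, n) for t in range(num_items)]
    vz = [vectorized_indices(t, num_items, n, vec) for t in range(num_items)]
    print(covers_each_once(gs, n), covers_each_once(vz, n))  # True True
    print(gs[0][:3])  # [0, 4, 8]
    print(vz[0])      # [0, 1, 2, 3, 4, 5, 6, 7, 32, 36]
```

Both schemes cover every element exactly once. The difference is that in the vectorized version, work-item 0 reads elements 0 through 7 as one contiguous chunk instead of touching 0, 4, 8, and so on one at a time.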