intel / intel-extension-for-pytorch

A Python package for extending the official PyTorch that can easily obtain performance on Intel platform
Apache License 2.0

Low Mixed Precision Performance #296

Open fredlarochelle opened 1 year ago

fredlarochelle commented 1 year ago

I am encountering some strange performance behavior on the A770, using the CIFAR-10 example from the documentation as a test case.

Using FP32, I get around 5.75s per epoch and, using BF16, I get around 6.2s per epoch. I also get exactly the same performance with and without ipex.optimize().
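For reference, here is roughly how the BF16 run is set up (a minimal sketch, not the exact documentation script; the toy model and random data below are just placeholders), using ipex.optimize() with dtype=torch.bfloat16 and torch.xpu.amp.autocast:

import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex

# Toy stand-in for the CIFAR-10 model, just to show where ipex.optimize() and autocast go
model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Flatten(),
                      nn.Linear(16 * 32 * 32, 10)).to("xpu")
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)

data = torch.randn(64, 3, 32, 32, device="xpu")
target = torch.randint(0, 10, (64,), device="xpu")

for _ in range(3):  # a few iterations, the first ones serve as warm-up
    optimizer.zero_grad()
    with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
        output = model(data)
        loss = criterion(output, target)
    loss.backward()
    optimizer.step()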

Also, when I compare the performance with a Tesla T4 on Colab, it runs each epoch in around 1s in FP32 and around 0.25s in FP16. Way faster, even though the A770 technically has better specs...

Are the XMX engines being used on Arc GPUs? #258 suggests yes.

I don't know if it might be related, but I get the following warnings when running the example (EDIT: I started going through the code in the repo; these warnings are not related to the current issue):

/home/fred/.local/lib/python3.10/site-packages/intel_extension_for_pytorch/frontend.py:447: UserWarning: For XPU device, the split master weight is unsupported for now, so temp to disable it
  warnings.warn("For XPU device, the split master weight is unsupported for now, so temp to disable it")
/home/fred/.local/lib/python3.10/site-packages/intel_extension_for_pytorch/frontend.py:457: UserWarning: For XPU device to save valuable device memory, temp to do optimization on inplaced model, so make inplace to be true
  warnings.warn(
/home/fred/.local/lib/python3.10/site-packages/intel_extension_for_pytorch/frontend.py:464: UserWarning: For XPU, the weight prepack and sample input are disabled. The onednn layout is automatically chosen to use
  warnings.warn(

Ubuntu 22.04 with 1.13.10+xpu.

fredlarochelle commented 1 year ago

I did some more digging around. It doesn't seem to be a problem with mixed precision per se, but more of a general performance problem.

Using the profiler for FP32, I found that the aten::convolution_backward_overrideable operator is the most problematic: that operator alone takes 4.7s (80% of the time on the XPU), and aten::convolution_overrideable takes another 0.9s (16%).

But if we compare again with a Tesla T4 on Colab, the equivalent operators are faster on the A770, except for those two problematic ones.

Out of curiosity, I ran the same code on CPU and it took 7.5s (compared to 5.75s on XPU), but the most surprising part is that the torch_ipex::convolution_backward operator on CPU was way faster, taking only 3.3s (compared to 4.7s on XPU)!
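In case it helps reproduce the per-operator numbers, this is roughly how I collect them (a sketch using the legacy profiler with XPU support that IPEX documents; the exact API may vary between IPEX versions, and the tiny conv below is just a stand-in):

import torch
import intel_extension_for_pytorch as ipex

conv = torch.nn.Conv2d(3, 16, 3, padding=1).to("xpu")
x = torch.randn(64, 3, 32, 32, device="xpu")

with torch.autograd.profiler_legacy.profile(use_xpu=True) as prof:
    y = conv(x)
    y.sum().backward()
    torch.xpu.synchronize()

# sort by time spent on the XPU to surface ops like aten::convolution_backward_overrideable
print(prof.key_averages().table(sort_by="self_xpu_time_total"))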

asirvaiy commented 1 year ago

Hi @fredlarochelle, thanks for reporting. I have already seen this issue and we are working on it. A few things to note and to check: always keep a few warm-up iterations (at least 2) before timing. Also, are you using torch.xpu.synchronize() on XPU and torch.cuda.synchronize() on CUDA? When we calculate latency with torch.xpu.synchronize() before taking the start and end times, the numbers actually look good, even though the weight-update part of each batch takes more time due to aten::convolution_backward_overrideable:

        # inside the training loop, after a couple of warm-up iterations
        # (assumes: from time import time)
        torch.xpu.synchronize()
        btime = time()

        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        torch.xpu.synchronize()
        etime = time()
asirvaiy commented 1 year ago

Just checked: it has nothing to do with synchronize(), and it seems not even to be aten::convolution_backward_overrideable.

Each data batch transfer from CPU to XPU is taking a significant amount of time, due to aten::copy_ here: data = data.to("xpu:1").
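A quick way to isolate that per-batch copy cost (just a sketch; the batch shape is an assumption):

import time
import torch
import intel_extension_for_pytorch as ipex

data = torch.randn(128, 3, 32, 32)  # CPU batch, roughly CIFAR-10 sized

for _ in range(5):  # the first iterations act as warm-up
    torch.xpu.synchronize()
    t0 = time.time()
    data_xpu = data.to("xpu")
    torch.xpu.synchronize()
    print("aten::copy_ time: %.6f s" % (time.time() - t0))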

fredlarochelle commented 1 year ago

Using torch.xpu.synchronize(), I actually get slightly slower timings (around 6.2s, compared to 5.75s), and that's after 3-4 iterations, where it's pretty stable.

Yeah, aten::copy_ is pretty slow too at around 0.9s (4x slower than Colab in my case, but that can depend on the CPU, memory, ...), but the main culprit is still aten::convolution_backward_overrideable.

jingxu10 commented 1 year ago

That sample code is just meant to show the usage of IPEX APIs, not to compare performance. Would you try the training example with a larger batch size?

BA8F0D39 commented 1 year ago

@jingxu10 @fredlarochelle

As @fredlarochelle said, aten::copy_ is extremely slow.

The bottleneck seems to be the GPU-to-GPU transfer.

Example code testing GPU-to-GPU memory copy:

import torch
import torchvision

import intel_extension_for_pytorch as ipex

import time

sizes = [1000, 1000, 1000, 1000, 1000, 2000, 3000, 10000, 20000, 30000, 40000, 42000, 43000, 44000, 45000]

for size in sizes:

    # allocate two bf16 tensors on the GPU (16 bits per element)
    array0 = torch.rand(size, size, dtype=torch.bfloat16).to("xpu")
    array1 = torch.rand(size, size, dtype=torch.bfloat16).to("xpu")

    # time a device-to-device copy, with synchronize() before and after
    torch.xpu.synchronize()
    start = time.time()
    array0 = torch.clone(array1)
    torch.xpu.synchronize()
    end = time.time()
    transferrate = (size * size * 16) / (end - start)  # bits per second for one copy of the tensor
    datasize = (size * size * 16)                      # bits
    print("==========")
    print("Transfering " + str(datasize / 8E9) + " GB")
    print("Bandwidth " + str(transferrate / 8E9) + " GB/s")
    print("==========")

    torch.xpu.empty_cache()

It bottlenecks at a maximum transfer rate of about 100 GB/s, while the A770's bandwidth is 512 GB/s. Most importantly, transferring small amounts of data is extremely slow: small batches only achieve about 30 GB/s, which is 6% utilization of the theoretical maximum bandwidth of 512 GB/s.

==========
Transfering 0.002 GB
Bandwidth 29.746836879432625 GB/s
==========
==========
Transfering 0.008 GB
Bandwidth 50.45779248120301 GB/s
==========
==========
Transfering 0.018 GB
Bandwidth 71.29128611898017 GB/s
==========
==========
Transfering 0.2 GB
Bandwidth 80.71401905128451 GB/s
==========
==========
Transfering 0.8 GB
Bandwidth 91.80419151846785 GB/s
==========
==========
Transfering 1.8 GB
Bandwidth 102.16166711772665 GB/s
==========
==========
Transfering 3.2 GB
Bandwidth 103.13888713854288 GB/s
==========
==========
Transfering 3.528 GB
Bandwidth 102.9577837521917 GB/s
==========
BA8F0D39 commented 1 year ago

Matrix multiplication also seems to be memory-bandwidth limited.

import torch
import torchvision

import intel_extension_for_pytorch as ipex

import time

sizes = [100, 200, 300, 400, 500, 700, 800, 900, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 10000, 20000, 30000]

for i in range(10):
    for size in sizes:

        # bf16 operands created on CPU and copied to the GPU
        w = torch.rand(size, size, dtype=torch.bfloat16).to("xpu")
        x = torch.rand(size, size, dtype=torch.bfloat16).to("xpu")

        ops = 2 * size**3  # multiply-add count for an n x n matmul

        # time only the matmul, with synchronize() before and after
        torch.xpu.synchronize()
        start = time.time()
        y = torch.mm(w, x)
        torch.xpu.synchronize()
        end = time.time()
        duration = (end - start)
        tflops = ops / duration / 10**12
        print("==========")
        print("Size " + str(size) + " by " + str(size))
        print("Performance " + str(tflops) + " TFLOPS")
        print("==========")

        torch.xpu.empty_cache()

TFLOPS in matrix multiplication steadily increases up to a 30000 by 30000 matrix and then suddenly decreases. I am guessing matrix multiplication is limited by memory bandwidth and there are cache misses when fitting large matrices into VRAM.

==========
Size 100 by 100
Performance 0.011011431345646439 TFLOPS
==========
==========
Size 200 by 200
Performance 0.1333487885258964 TFLOPS
==========
==========
Size 300 by 300
Performance 0.4023397306761566 TFLOPS
==========
==========
Size 400 by 400
Performance 1.0155299684848484 TFLOPS
==========
==========
Size 500 by 500
Performance 0.9817501630740393 TFLOPS
==========
==========
Size 700 by 700
Performance 4.1549672471676296 TFLOPS
==========
==========
Size 800 by 800
Performance 4.508700568739496 TFLOPS
==========
==========
Size 900 by 900
Performance 5.3193192739425585 TFLOPS
==========
==========
Size 1000 by 1000
Performance 7.601463006346328 TFLOPS
==========
==========
Size 2000 by 2000
Performance 40.637242146577826 TFLOPS
==========
==========
Size 3000 by 3000
Performance 50.300903434917814 TFLOPS
==========
==========
Size 4000 by 4000
Performance 52.954898208148364 TFLOPS
==========
==========
Size 5000 by 5000
Performance 68.8153808348648 TFLOPS
==========
==========
Size 6000 by 6000
Performance 77.19263487094713 TFLOPS
==========
==========
Size 7000 by 7000
Performance 81.02188181087017 TFLOPS
==========
==========
Size 8000 by 8000
Performance 87.00238762927698 TFLOPS
==========
==========
Size 10000 by 10000
Performance 89.37010376841857 TFLOPS
==========
==========
Size 20000 by 20000
Performance 101.03734560466859 TFLOPS
==========
==========
Size 30000 by 30000
Performance 74.42093049111077 TFLOPS
==========
fredlarochelle commented 1 year ago

@BA8F0D39 Quick tip: when possible, creating the tensors directly on the XPU makes everything run way faster.

# Instead of
w = torch.rand(size, size, dtype=torch.bfloat16).to('xpu')
x = torch.rand(size, size, dtype=torch.bfloat16).to('xpu')

# Do this
w = torch.rand(size, size, dtype=torch.bfloat16, device='xpu')
x = torch.rand(size, size, dtype=torch.bfloat16, device='xpu')

For me, doing this makes your TFLOPS program run about 30x faster.

BA8F0D39 commented 1 year ago

@fredlarochelle Thanks for the tip. I was wrong to think torch.xpu.synchronize() synchronizes between CPU and GPU

Also, does your A770 GPU allow you to allocate more than 8GB of bf16 memory? My code crashes on the A770 if I try to allocate more than 8GB of bf16. I have an A770 16 GB card.

fredlarochelle commented 1 year ago

No, like with CUDA, torch.xpu.synchronize() waits for all kernels on the GPU to complete.

And I can also confirm that I can't seem to allocate over 8GB at all on the A770 16GB (with fp32 and bf16). I get the following error: RuntimeError: Native API failed. Native API returns: -997 (The plugin has emitted a backend specific error) -997 (The plugin has emitted a backend specific error)

BA8F0D39 commented 1 year ago

@jingxu10
Does PyTorch enforce limits on how much memory you can allocate and on the memory bandwidth?

gujinghui commented 1 year ago

@jingxu10 Does PyTorch enforce limits on how much memory you can allocate and on the memory bandwidth?

No, PyTorch does not. It should be a driver limit.

lit199 commented 1 year ago

It seems that some of the operations are using a slow code path. If channels_last is not enabled, convolution forward and convolution weight backward do not use blocked formats and fall back to a slow version (see dnnl_normal.log). If channels_last is enabled, the convolution backward functions use the fast blocked format, but forward is even slower (see dnnl_channels_last.log). dnnl_log_compare.txt compares the results with and without blocking using benchdnn; the difference is >30x on an A750.
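For reference, this is a minimal sketch of how the channels_last path can be exercised from PyTorch (toy shapes, only meant to show where the memory_format conversions go):

import torch
import intel_extension_for_pytorch as ipex

model = torch.nn.Conv2d(3, 64, 3, padding=1).to("xpu").to(memory_format=torch.channels_last)
x = torch.randn(32, 3, 224, 224, device="xpu").to(memory_format=torch.channels_last)

y = model(x)        # forward convolution
y.sum().backward()  # convolution backward (weight + data)
torch.xpu.synchronize()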

BA8F0D39 commented 1 year ago

@lit199 The kernel launch times are also very slow, 10x slower than on other GPUs.

Platform: Intel(R) OpenCL HD Graphics
  Device: Intel(R) Graphics [0x56a0]
    Driver version  : 22.43.30 (Linux x64)
    Compute units   : 512
    Clock frequency : 2400 MHz

    Global memory bandwidth (GBPS)
      float   : 397.87
      float2  : 403.63
      float4  : 407.18
      float8  : 416.18
      float16 : 421.80

    Single-precision compute (GFLOPS)
      float   : 13017.51
      float2  : 11136.49
      float4  : 10402.49
      float8  : 10026.09
      float16 : 9695.57

    Half-precision compute (GFLOPS)
      half   : 19543.72
      half2  : 19489.39
      half4  : 19523.66
      half8  : 19454.95
      half16 : 19336.14

    No double precision support! Skipped

    Integer compute (GIOPS)
      int   : 4380.31
      int2  : 4385.50
      int4  : 4403.38
      int8  : 4273.37
      int16 : 5004.16

    Integer compute Fast 24bit (GIOPS)
      int   : 4361.75
      int2  : 4369.68
      int4  : 4387.98
      int8  : 4265.73
      int16 : 4995.43

    Transfer bandwidth (GBPS)
      enqueueWriteBuffer              : 21.64
      enqueueReadBuffer               : 8.92
      enqueueWriteBuffer non-blocking : 22.81
      enqueueReadBuffer non-blocking  : 9.10
      enqueueMapBuffer(for read)      : 20.58
        memcpy from mapped ptr        : 22.62
      enqueueUnmap(after write)       : 23.62
        memcpy to mapped ptr          : 22.44

    Kernel launch latency : 34.76 us

A kernel launch latency of 34.76 us on the A770 16 GB is 10x larger than the RTX 2080 SUPER's 3.46 us: https://github.com/krrishnarraj/clpeak/blob/master/results/NVIDIA_CUDA/GeForce_RTX_2080_Super.log

lit199 commented 1 year ago

@BA8F0D39 Slow kernel launch is bad, but it is not the main culprit here. IPEX/PyTorch/oneDNN is choosing a slower version when a faster one is available. Also, I am getting 8.97 us:

Platform: Intel(R) OpenCL HD Graphics
  Device: Intel(R) Graphics [0x56a1]
    Driver version  : 22.49.25018.23 (Linux x64)
    Compute units   : 448
    Clock frequency : 2400 MHz

    Global memory bandwidth (GBPS)
      float   : 398.32
      float2  : 398.83
      float4  : 406.75
      float8  : 408.93
      float16 : 412.17

    Single-precision compute (GFLOPS)
      float   : 11394.35
      float2  : 9748.35
      float4  : 9097.27
      float8  : 8777.80
      float16 : 8501.48

    Half-precision compute (GFLOPS)
      half   : 17126.20
      half2  : 17073.53
      half4  : 17102.17
      half8  : 17042.17
      half16 : 16939.83

    No double precision support! Skipped

    Integer compute (GIOPS)
      int   : 4119.66
      int2  : 4122.17
      int4  : 4135.86
      int8  : 4013.78
      int16 : 4704.14

    Integer compute Fast 24bit (GIOPS)
      int   : 4098.38
      int2  : 4105.63
      int4  : 4123.31
      int8  : 4001.99
      int16 : 4698.86

    Transfer bandwidth (GBPS)
      enqueueWriteBuffer              : 8.75
      enqueueReadBuffer               : 4.81
      enqueueWriteBuffer non-blocking : 11.34
      enqueueReadBuffer non-blocking  : 5.46
      enqueueMapBuffer(for read)      : 10.03
        memcpy from mapped ptr        : 15.42
      enqueueUnmap(after write)       : 11.71
        memcpy to mapped ptr          : 15.81

    Kernel launch latency : 8.97 us

@jingxu10 Setting IPEX_XPU_ONEDNN_LAYOUT=1 seems to add the necessary reorder steps. I am getting 4-5 batches/s after this change. Here is some oneDNN log after the change: dnnl_reorder.csv
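Since IPEX_XPU_ONEDNN_LAYOUT is read from the environment, it needs to be set before the extension initializes. Exporting it in the shell works, or (assuming the flag is picked up at import time) from Python:

import os
os.environ["IPEX_XPU_ONEDNN_LAYOUT"] = "1"  # set before importing the extension

import torch
import intel_extension_for_pytorch as ipex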

BA8F0D39 commented 1 year ago

@lit199 channels_last=True did improve training performance by 10% to 20%. IPEX_XPU_ONEDNN_LAYOUT=1 doesn't do anything for the A770. It seems like there are more than 10 different bugs causing bad performance, from the kernel drivers to the Intel compute runtime to oneAPI to oneDNN to PyTorch.

  1. 8.97 us kernel launch time is 3x slower than other GPUs.
  2. enqueueReadBuffer is 4.81GBPS, which is 2x slower than other GPUs.
  3. Memory transfer failures in OpenCL/Level Zero
    
    [  FAILED  ] UsmSharedMigrateGpuForFillTest/UsmSharedMigrateGpuForFillTest.Test/1
    FAILED assertion ASSERT_CL_SUCCESS(clEnqueueMemFillINTEL(opencl.commandQueue, buffer, &pattern, 1, arguments.bufferSize, 0, nullptr, nullptr))
    value: -59 (CL_INVALID_OPERATION)
    Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/ocl/usm_shared_migrate_gpu_for_fill_ocl.cpp:48

[ FAILED ] UsmSharedMigrateGpuForFillTest/UsmSharedMigrateGpuForFillTest.Test/3 FAILED assertion ASSERT_CL_SUCCESS(clEnqueueMemFillINTEL(opencl.commandQueue, buffer, &pattern, 1, arguments.bufferSize, 0, nullptr, nullptr)) value: -59 (CL_INVALID_OPERATION) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/ocl/usm_shared_migrate_gpu_for_fill_ocl.cpp:48

[ FAILED ] UsmSharedMigrateGpuForFillTest/UsmSharedMigrateGpuForFillTest.Test/5 FAILED assertion ASSERT_CL_SUCCESS(clEnqueueMemFillINTEL(opencl.commandQueue, buffer, &pattern, 1, arguments.bufferSize, 0, nullptr, nullptr)) value: -59 (CL_INVALID_OPERATION) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/ocl/usm_shared_migrate_gpu_for_fill_ocl.cpp:48

[ FAILED ] UsmSharedMigrateGpuForFillTest/UsmSharedMigrateGpuForFillTest.Test/7 FAILED assertion ASSERT_CL_SUCCESS(clEnqueueMemFillINTEL(opencl.commandQueue, buffer, &pattern, 1, arguments.bufferSize, 0, nullptr, nullptr)) value: -59 (CL_INVALID_OPERATION) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/ocl/usm_shared_migrate_gpu_for_fill_ocl.cpp:48

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/49 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/51 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/53 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/55 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/57 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/59 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/61 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/63 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/65 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/67 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/69 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/71 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/73 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/75 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/77 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/79 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/81 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/83 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/85 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/87 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/89 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/91 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/93 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] StreamMemoryTest/StreamMemoryTest.Test/95 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/stream_memory_l0.cpp:173

[ FAILED ] UsmCopyTest/UsmCopyTest.Test/72 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/usm_copy_l0.cpp:101

[ FAILED ] UsmCopyTest/UsmCopyTest.Test/76 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/usm_copy_l0.cpp:101

[ FAILED ] UsmCopyTest/UsmCopyTest.Test/80 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/usm_copy_l0.cpp:101

[ FAILED ] UsmCopyTest/UsmCopyTest.Test/84 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/usm_copy_l0.cpp:101

[ FAILED ] UsmCopyTest/UsmCopyTest.Test/88 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/usm_copy_l0.cpp:101

[ FAILED ] UsmCopyTest/UsmCopyTest.Test/92 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/usm_copy_l0.cpp:101

[ FAILED ] UsmCopyTest/UsmCopyTest.Test/96 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/usm_copy_l0.cpp:101

[ FAILED ] UsmCopyTest/UsmCopyTest.Test/100 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/usm_copy_l0.cpp:101

[ FAILED ] UsmCopyTest/UsmCopyTest.Test/104 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/usm_copy_l0.cpp:101

[ FAILED ] UsmCopyTest/UsmCopyTest.Test/108 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/usm_copy_l0.cpp:101

[ FAILED ] UsmCopyTest/UsmCopyTest.Test/112 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/usm_copy_l0.cpp:101

[ FAILED ] UsmCopyTest/UsmCopyTest.Test/116 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/usm_copy_l0.cpp:101

[ FAILED ] UsmCopyTest/UsmCopyTest.Test/120 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/usm_copy_l0.cpp:101

[ FAILED ] UsmCopyTest/UsmCopyTest.Test/124 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/usm_copy_l0.cpp:101

[ FAILED ] UsmCopyTest/UsmCopyTest.Test/128 FAILED assertion ASSERT_ZE_RESULT_SUCCESS(zeEventQueryKernelTimestamp(event, &timestampResult)) value: 1 (ZE_RESULT_NOT_READY) Location: /opt/test/compute-benchmarks/compute-benchmarks/source/benchmarks/memory_benchmark/implementations/l0/usm_copy_l0.cpp:101

4. IPC and multithreading support is nonexistent in PyTorch and the compute runtime. Using more than num_workers=1 crashes, and the compute runtime fails all the multithreading and IPC tests:
                        MultiProcessComputeSharedBuffer(api=l0 tiles=Tile0 processesPerTile=1 workgroupsPerProcess=1 synchronize=0)                                        ERROR
                        MultiProcessComputeSharedBuffer(api=l0 tiles=Tile0 processesPerTile=1 workgroupsPerProcess=1 synchronize=1)                                        ERROR
                      MultiProcessComputeSharedBuffer(api=l0 tiles=Tile0 processesPerTile=1 workgroupsPerProcess=300 synchronize=0)                                        ERROR
                      MultiProcessComputeSharedBuffer(api=l0 tiles=Tile0 processesPerTile=1 workgroupsPerProcess=300 synchronize=1)                                        ERROR
                        MultiProcessComputeSharedBuffer(api=l0 tiles=Tile0 processesPerTile=2 workgroupsPerProcess=1 synchronize=0)                                        ERROR
                        MultiProcessComputeSharedBuffer(api=l0 tiles=Tile0 processesPerTile=2 workgroupsPerProcess=1 synchronize=1)                                        ERROR
                      MultiProcessComputeSharedBuffer(api=l0 tiles=Tile0 processesPerTile=2 workgroupsPerProcess=300 synchronize=0)                                        ERROR
                      MultiProcessComputeSharedBuffer(api=l0 tiles=Tile0 processesPerTile=2 workgroupsPerProcess=300 synchronize=1)                                        ERROR
                        MultiProcessComputeSharedBuffer(api=l0 tiles=Tile0 processesPerTile=4 workgroupsPerProcess=1 synchronize=0)                                        ERROR
                        MultiProcessComputeSharedBuffer(api=l0 tiles=Tile0 processesPerTile=4 workgroupsPerProcess=1 synchronize=1)                                        ERROR
                      MultiProcessComputeSharedBuffer(api=l0 tiles=Tile0 processesPerTile=4 workgroupsPerProcess=300 synchronize=0)                                        ERROR
                      MultiProcessComputeSharedBuffer(api=l0 tiles=Tile0 processesPerTile=4 workgroupsPerProcess=300 synchronize=1)                                        ERROR
fredlarochelle commented 1 year ago

Any updates on the slow GPU-to-GPU transfer speed? It's way better than it was, but it's still 3-5x slower than a Tesla T4 on Colab for transfers smaller than around 0.01-0.12 GB. Above that, the performance of the A770 is around where it should be, a bit under double the speed of the T4, when comparing the theoretical memory specs of each GPU. A single transfer is also significantly slower on the A770. Finally, comparing FP16/BF16 with FP32 on the A770, FP32 is about half the speed, except for small transfers where it's about the same.

Also, for the weird 4GB/8GB memory issues, I have isolated the problem further. I don't have the -997 error anymore; it now only seems to happen for array sizes over around 48900x48900 for BF16, with error -5 (I haven't tested other data types), plus the gibberish over 4GB:

import torch
import intel_extension_for_pytorch as ipex

# Returns 16.225243135GB on A770, all good
print(f"The total memory of the card is {torch.xpu.get_device_properties('xpu').total_memory / 1e9}GB.")

# We are able to allocate an array over 4GB, all good, but the contents are gibberish
array = torch.rand(48000, 48000, dtype=torch.bfloat16, device='xpu')
print(f"The total allocated memory of the card is {torch.xpu.memory_allocated() / 1e9}GB and the memory of the array is {(array.element_size() * array.nelement()) / 1e9}GB")

# But somewhere in-between a 48800x48800 and a 48900x48900 array, we get a -5 (PI_ERROR_OUT_OF_RESOURCES)
array = torch.rand(48900, 48900, dtype=torch.bfloat16, device='xpu')

# We are also able to allocate over 8GB of memory in multiple arrays, 9.6GB with this example, all good
array = torch.rand(40000, 40000, dtype=torch.bfloat16, device='xpu')
array1 = torch.rand(40000, 40000, dtype=torch.bfloat16, device='xpu')
array2 = torch.rand(40000, 40000, dtype=torch.bfloat16, device='xpu')
print(f"Total memory allocated from the 3 arrays is {torch.xpu.memory_allocated() / 1e9}GB")

Take note that each "section of code" is run with a restarted kernel, especially since the PI_ERROR_OUT_OF_RESOURCES error makes the kernel hang.

For the gibberish with an array over 4GB, see #325.

fengyuan14 commented 1 year ago

[Quoting @BA8F0D39's GPU-to-GPU copy benchmark and ~100 GB/s bandwidth results from above.]

@BA8F0D39 @fredlarochelle We verified your case locally and got a similar memory bandwidth for BF16, ~100 GB/s. We also tried Float and got ~200 GB/s, while clpeak seems to show 400 GB/s. The issue might be that the case is instruction bound: the memory access instructions cannot feed DDR/HBM well. We are checking the whole stack for SIMD enabling on Arc.

fredlarochelle commented 1 year ago

@arthuryuan1987 I get a higher bandwidth than that for BF16 (same for FP16), ~200 GB/s (peak of 222 GB/s) for transfers over ~0.2 GB (below that it's lower, and drastically lower below ~0.01-0.12 GB).

On my system, it's FP32 that maxes out at ~110 GB/s.

For FP16 and BF16 with bigger transfers, the percentage of the theoretical max bandwidth I get is about the same as a Tesla T4 on Colab. The problems are really with smaller transfers and with FP32. Also, single transfers are really slow.

BA8F0D39 commented 1 year ago

@fredlarochelle Using the 5.19 out-of-tree kernel drivers (https://dgpu-docs.intel.com/driver/client/overview.html), export IPEX_XPU_ONEDNN_LAYOUT=1, and BF16, the training takes 1s per epoch. This is still slower than a Tesla T4 (the A770 should be about 2x faster than a T4), but the oneDNN memory layout does increase performance.

BA8F0D39 commented 1 year ago

@arthuryuan1987 I also noticed IPEX crashes with multiple workers. Is multithreading disabled? Can multithreading speed up multiple GPU memory transfers?

fredlarochelle commented 1 year ago

@BA8F0D39 Can confirm, I get around 1s per epoch too, which is still way better than it was in February. The performance is also definitely better with bigger datasets/bigger memory transfers and with some optimization (this simple example in the doc is not optimized at all), but it is probably still a bit under one order of magnitude off for this particular example.

The Tesla T4's theoretical specs are 65.13 TFLOPS at FP16 with 320 GB/s of memory bandwidth, vs. the A770's 157.23 TFLOPS for FP16/BF16 (65536 ops per clock cycle) and 560 GB/s. So it can possibly be even more than 2x a T4.
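As a back-of-the-envelope check of that A770 number (assuming the 2400 MHz clock that clpeak reported earlier in this thread):

ops_per_clock = 65536        # FP16/BF16 ops per clock across the XMX engines
clock_hz = 2.4e9             # 2400 MHz
print(ops_per_clock * clock_hz / 1e12)  # ~157.3 TFLOPS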

For multiple workers, I have no problem with setting num_workers to something other than the default value of 0 in the dataloader, if that is what you are referring to.

fengyuan14 commented 1 year ago

@fredlarochelle

  1. As for the small-size copy (your case is cache irrelevant, so all data has to travel over DDR/HBM): not enough instructions are issued to keep DDR/HBM fed, so the case cannot show a high bandwidth. The small-size problem is not a problem of DDR/HBM throughput, but a problem of DDR/HBM latency.
  2. As for FP32 vs FP16/BF16: for the same number of elements (in other words, the same number of instructions), FP32 moves twice as many bytes. When the case is bound on instructions (a small number of elements), you see the same instruction latency but twice the bytes transferred, so FP32's bandwidth should be double FP16's. FP32 should look better if the same SIMD width is used; I may not be explaining your observation exactly, and the difference is possibly due to different SIMD applicability per data type.
  3. If the case is bound on DDR/HBM, leveraging SIMD optimization to issue more memory accesses in each DDR/HBM window (throughput preference: generally high latency per window, high throughput) will show a high bandwidth.
fengyuan14 commented 1 year ago

@fredlarochelle Using the 5.19 out-of-tree kernel drivers (https://dgpu-docs.intel.com/driver/client/overview.html), export IPEX_XPU_ONEDNN_LAYOUT=1, and BF16, the training takes 1s per epoch. This is still slower than a Tesla T4 (the A770 should be about 2x faster than a T4), but the oneDNN memory layout does increase performance.

@BA8F0D39 That's right. IPEX exposes the oneDNN layout to users to reach better performance. Especially in this case, which is not bound on computation (the computation latency is trivial), enabling the oneDNN layout avoids additional layout reordering (additional memory accesses).

fengyuan14 commented 1 year ago

@arthuryuan1987 I also noticed IPEX crashes with multiple workers. Is multithreading disabled? Can multithreading speed up multiple GPU memory transfers?

I suppose you are suggesting using multi-threading to feed the memory port. Your intention is right. To keep the memory port well fed, you can:

  1. Increase the number of instructions (which should be what you are trying).
  2. Increase the data per instruction without increasing the number of instructions (SIMD).

Regarding the cases here, your approach should not help, because:

  1. The occupancy of the Arc GPU cores is already high in your case; in other words, the number of instructions is sufficient.
  2. So far, in IPEX's execution mode, each kernel submission occupies the whole GPU (all cores), so copy kernels submitted from multiple CPU cores won't bring any benefit.
BA8F0D39 commented 1 year ago

@fredlarochelle @arthuryuan1987
I found something weird with latency. Printing a float32 takes 1340 us in IPEX, which is fine. But transferring a single float32 number takes 0.142s. Why does this take so long? The GPU-to-GPU transfer rate is 224.56 bit/s for 1 float32.

import time
import torch
import torchvision.models as models

import numpy as np
import intel_extension_for_pytorch as ipex

torch.manual_seed(0)

x = torch.rand(1, 1, dtype=torch.float32, device='xpu')

# time copying a single float32 to the host and printing it
torch.xpu.synchronize()
start = time.time()
print(x.cpu())
end = time.time()

print("Print Time in Seconds: %.20f " % (end - start))

torch.manual_seed(2)

x = torch.rand(1, 1, dtype=torch.float32, device='xpu')
y = torch.rand(1, 1, dtype=torch.float32, device='xpu')

# time an on-device clone of a single float32, followed by the copy to host and print
torch.xpu.synchronize()
start = time.time()
y = x.clone()
print(y.cpu())
end = time.time()

print("Data Transfer Time in Seconds: %.20f " % (end - start))

On A770 16 GB

tensor([[0.9179]])
Print Time in Seconds: 0.00134086608886718750 
tensor([[0.9696]])
Data Transfer Time in Seconds: 0.14255475997924804688
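A slightly more isolated variant of that second measurement (a sketch with the same tensors, timing only the on-device clone and keeping the print and the .cpu() copy outside the timed region) would be:

import time
import torch
import intel_extension_for_pytorch as ipex

torch.manual_seed(2)
x = torch.rand(1, 1, dtype=torch.float32, device='xpu')

torch.xpu.synchronize()
start = time.time()
y = x.clone()
torch.xpu.synchronize()
end = time.time()

print(y.cpu())
print("Isolated Clone Time in Seconds: %.20f " % (end - start))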