Closed fredlarochelle closed 1 year ago
@fredlarochelle great question :). We are fully aware of the importance of the XLA ecosystem (JAX, TF-XLA, PT-XLA), and we are actively discussing with Google how to bring Intel GPU support to XLA. The direction we have aligned on is to use the PJRT plugin interface to integrate Intel GPU into OpenXLA first, and we are working on that internally. I expect we will open our XLA solution to support JAX in the near future. Please keep monitoring our repo; I will also update here when the Intel GPU XLA solution is out.
PS: you can check this RFC to understand more about the integration design if you are interested: https://github.com/openxla/community/pull/33
Hi @fredlarochelle,
Good news! Intel Extension for TensorFlow v1.2.0 adopted the PJRT plugin interface to implement an Intel GPU backend for OpenXLA, with experimental support. You can follow the instructions outlined here to build the necessary XLA extension using Bazel and get started with the provided JAX example.
Also be aware of gcc and g++ version mismatches in your environment, as they could lead to dependency issues.
Thank you, and let us know if you have any other questions!
@yehudaorel Awesome! I will set aside some time over the next few days to try this out and definitely get back to you!
@yehudaorel The build is failing for me. Running `bazel build --verbose_failures --config=jax -c opt //itex:libitex_xla_extension.so > build_output.txt 2>&1` to capture the output, I get the following build_output.txt.
For info, the system is an Intel Xeon E5-2695 v3 with 128 GB of RAM and an Intel Arc A770. Based on previous experience building IPEX, I used `ats-m150` as the device type for AOT (the A770 is not in your documentation).
Do I need to build LLVM as for IPEX? Or should I simply try to build XLA in the same conda env as IPEX, with the already-built LLVM?
Hi, @fredlarochelle
The build is failing for me. Running `bazel build --verbose_failures --config=jax -c opt //itex:libitex_xla_extension.so > build_output.txt 2>&1` to capture the output, I get the following build_output.txt.
From the build output:
/usr/bin/ld: cannot find -lstdc++: No such file or directory
Depending on how you set up your environment, this issue might be caused by mismatched versions of gcc and g++. As I mentioned in my first comment, this is a known issue with the dpcpp compiler and could be the case here as well. Take a look at the solution provided in that ticket and see if the same applies to your env.
For info, the system is an Intel Xeon E5-2695 v3 with 128 GB of RAM and an Intel Arc A770. Based on previous experience building IPEX, I used `ats-m150` as the device type for AOT (the A770 is not in your documentation).
Although the A770 is not mentioned specifically in the docs yet, it is based on the same DG2-512 chip variant as the A730M & Flex 170: `acm-g10`. I would recommend trying both `acm-g10` & `ats-m150`.
Keep in mind that support for Arc A-series GPUs is experimental and therefore prone to breaking. Here is the list of GPUs that were tested & verified:
| GPU | device type |
|---|---|
| Intel® Data Center GPU Flex Series 170 | ats-m150 |
| Intel® Data Center GPU Flex Series 140 | ats-m75 |
| Intel® Data Center GPU Max Series | pvc |
| Intel® Arc™ A730M | acm-g10 |
| Intel® Arc™ A380 | acm-g11 |
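To make the mapping above easier to reuse when scripting a build, here is a minimal sketch that looks up the AOT device type for a GPU name. The verified entries are copied from the table; the helper name `aot_device_type` is hypothetical, and the A770 entry is an assumption based on the `acm-g10` recommendation in this comment (experimental, not in the verified list).

```python
from typing import Optional

# Tested & verified GPUs and their AOT device types, copied from the table above.
VERIFIED_DEVICE_TYPES = {
    "Intel Data Center GPU Flex Series 170": "ats-m150",
    "Intel Data Center GPU Flex Series 140": "ats-m75",
    "Intel Data Center GPU Max Series": "pvc",
    "Intel Arc A730M": "acm-g10",
    "Intel Arc A380": "acm-g11",
}

# Assumption: not in the verified list. Suggested in this thread because the
# A770 is said to share the DG2-512 chip variant with the A730M.
EXPERIMENTAL_DEVICE_TYPES = {
    "Intel Arc A770": "acm-g10",
}


def aot_device_type(gpu_name: str) -> Optional[str]:
    """Return the AOT device type for a GPU name, or None if it is unknown."""
    # Drop trademark symbols so "Intel® Arc™ A770" and "Intel Arc A770" both match.
    normalized = gpu_name.replace("®", "").replace("™", "").strip()
    if normalized in VERIFIED_DEVICE_TYPES:
        return VERIFIED_DEVICE_TYPES[normalized]
    return EXPERIMENTAL_DEVICE_TYPES.get(normalized)


print(aot_device_type("Intel® Arc™ A770"))  # prints "acm-g10"
```

Keeping the experimental entry in a separate dictionary makes it obvious which values come from the verified list and which are only suggestions.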
Also take a look at the ITEX provided docs for building from source procedure: https://github.com/intel/intel-extension-for-tensorflow/blob/main/docs/install/how_to_build.md
Apologies for the delay. I have taken a closer look and it does look pretty good so far!
While I haven't conducted extensive testing, it seems to be working great!
However, I have encountered three problems. First, it doesn't work properly in a Jupyter notebook. Whenever I attempt to run anything apart from the imports, the notebook crashes with a message similar to the following in the logs: `[warn] Cell completed with errors { message: 'Canceled future for execute_request message before replies were done'}`.
Additionally, I am experiencing OOM errors on the A770 even at relatively low memory usage (around 3GB).
Finally, I need to take a deeper look at what is going on, but the Jax test suite doesn't want to run at all.
Regarding the `-lstdc++` error, it was a simple fix: I resolved it by installing `g++-12`. It appears that different components in the Intel ecosystem require different versions of GCC. For instance, Ubuntu 22.04 and IPEX both rely on GCC 11. However, the i915 drivers seem to require `gcc-12`, which ends up installed without the corresponding `g++-12`.
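A quick way to spot this kind of mismatch before starting a long build is to compare the major versions that `gcc` and `g++` report. This is a minimal sketch: the helper names are hypothetical, and it assumes the `gcc` and `g++` found on `PATH` are the ones the build actually picks up, which may not hold for every Bazel configuration.

```python
import re
import shutil
import subprocess
from typing import Optional


def major_version(version_text: str) -> Optional[int]:
    """Extract the leading major version number from text such as '12.1.0'."""
    match = re.match(r"\s*(\d+)", version_text)
    return int(match.group(1)) if match else None


def compiler_major(command: str) -> Optional[int]:
    """Return the major version from `<command> -dumpversion`, or None if absent."""
    if shutil.which(command) is None:
        return None  # the compiler is not on PATH
    result = subprocess.run(
        [command, "-dumpversion"], capture_output=True, text=True, check=False
    )
    return major_version(result.stdout)


def versions_match(gcc: Optional[int], gxx: Optional[int]) -> bool:
    """True only when both compilers were found and report the same major version."""
    return gcc is not None and gcc == gxx


if __name__ == "__main__":
    gcc, gxx = compiler_major("gcc"), compiler_major("g++")
    print(f"gcc major: {gcc}, g++ major: {gxx}")
    if not versions_match(gcc, gxx):
        print("gcc and g++ differ or one is missing; linking against libstdc++ may fail.")
```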
Btw, my tests on the A770 were done with the device type set to `ats-m150`; I haven't tested `acm-g10`.
Any reasons why float16 operations are about 1.3-1.4x faster than bfloat16?
@fredlarochelle
Thanks for your feedback.
Any reasons why float16 operations are about 1.3-1.4x faster than bfloat16?
Could you share some more details? The test case, test code, environment etc.
@fredlarochelle
OOM errors on the A770 even at relatively low memory usage (around 3GB).
Finally, I need to take a deeper look at what is going on, but Jax test suite doesn't want to run at all.
What's your environment? Could you check it as follows and post the output here?
======================== Check Python ========================
python3.9 is installed.
==================== Check Python Passed =====================
========================== Check OS ==========================
OS ubuntu:22.04 is Supported.
====================== Check OS Passed =======================
====================== Check Tensorflow ======================
tensorflow2.10 is installed.
================== Check Tensorflow Passed ===================
=================== Check Intel GPU Driver ===================
The script `./tools/env_check.sh` can't run if we follow your installation instructions here, since we are not building TensorFlow, only Jax. Can I run a second Bazel build command to build TensorFlow without messing with Jax?
For my environment, I am on Ubuntu 22.04 with an A770, in a conda env with Python 3.10. Neither the drivers nor oneAPI should be an issue here, since I am building and running IPEX. As an alternative to your `env_check.sh` script, here is the output of `collect_env.py` from IPEX. Take note that I am using GCC 12.1.0 to build Jax, not 11.3.0 as for IPEX:
Collecting environment information...
PyTorch version: 2.0.0a0+gite9ebda2
PyTorch CXX11 ABI: Yes
IPEX version: 2.0.110+gitb7412a4
IPEX commit: b7412a42
Build type: Release
OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04.1) 11.3.0
Clang version: N/A
IGC version: 2023.1.0 (2023.1.0.20230320)
CMake version: version 3.26.4
Libc version: glibc-2.35
Python version: 3.10.11 (main, May 16 2023, 00:28:57) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.19.0-46-generic-x86_64-with-glibc2.35
Is XPU available: True
DPCPP runtime version: 2023.1.0
MKL version: 2023.1.0
GPU models and configuration:
[0] _DeviceProperties(name='Intel(R) Arc(TM) A770 Graphics', platform_name='Intel(R) Level-Zero', dev_type='gpu, support_fp64=0, total_memory=15473MB, max_compute_units=512, gpu_eu_count=512)
Intel OpenCL ICD version: 23.17.26241.21-647~22.04
Level Zero version: 1.3.26241.21-647~22.04
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 28
On-line CPU(s) list: 0-27
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) CPU E5-2695 v3 @ 2.30GHz
CPU family: 6
Model: 63
Thread(s) per core: 2
Core(s) per socket: 14
Socket(s): 1
Stepping: 2
CPU max MHz: 3300.0000
CPU min MHz: 1200.0000
BogoMIPS: 4589.29
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm cpuid_fault epb invpcid_single pti intel_ppin ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm xsaveopt cqm_llc cqm_occup_llc dtherm ida arat pln pts md_clear flush_l1d
Virtualization: VT-x
L1d cache: 448 KiB (14 instances)
L1i cache: 448 KiB (14 instances)
L2 cache: 3.5 MiB (14 instances)
L3 cache: 35 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-27
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
...
As for more information, nothing runs in a Jupyter notebook; the Jupyter logs say the error is `'Canceled future for execute_request message before replies were done'`. Anyway, here is a minimal reproducible example (MRE):
# first cell
!source /opt/intel/oneapi/setvars.sh
# second cell
import jax
import jax.numpy as jnp
# third cell
key = jax.random.PRNGKey(0) # Crashes here, but could be anything calling jax from my testing
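One thing worth ruling out, although it is not confirmed as the cause in this thread: in a notebook, a `!source ...` line runs in a temporary subshell, so variables exported by `setvars.sh` never reach the kernel process that later imports `jax`. The sketch below checks what the kernel actually sees. It assumes that `setvars.sh` exports `ONEAPI_ROOT` and `SETVARS_COMPLETED=1` and adds directories under `/opt/intel/oneapi` to `LD_LIBRARY_PATH`; the helper name `oneapi_env_report` is hypothetical.

```python
import os
from typing import Dict, Mapping


def oneapi_env_report(
    environ: Mapping[str, str], install_root: str = "/opt/intel/oneapi"
) -> Dict[str, bool]:
    """Report whether the given environment looks like setvars.sh was applied."""
    library_dirs = environ.get("LD_LIBRARY_PATH", "").split(os.pathsep)
    return {
        "ONEAPI_ROOT set": bool(environ.get("ONEAPI_ROOT")),
        "SETVARS_COMPLETED set": environ.get("SETVARS_COMPLETED") == "1",
        "oneAPI in LD_LIBRARY_PATH": any(
            directory.startswith(install_root) for directory in library_dirs
        ),
    }


# Run this in a notebook cell: it inspects the kernel process, not a subshell.
for check, ok in oneapi_env_report(os.environ).items():
    print(f"{check}: {'yes' if ok else 'NO'}")
```

If the checks fail in the notebook but pass in a terminal where `setvars.sh` was sourced, starting Jupyter from that terminal would be the next thing to try.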
For the OOM error, here is a quick matmul TFLOPS test I wrote. With `matrix_size` set to 28000, for example, the OOM error happens randomly, with an error saying it failed trying to allocate 2.92 GiB:
import time
import jax
import jax.numpy as jnp
key = jax.random.PRNGKey(0)
matrix_size = 28000  # Adjust matrix size as needed
num_runs = 20  # Number of runs to average
def tflops_run() -> float:
    # Generate random matrices
    matrix_a = jax.random.normal(key, (matrix_size, matrix_size), dtype=jnp.bfloat16)
    matrix_b = jax.random.normal(key, (matrix_size, matrix_size), dtype=jnp.bfloat16)
    # Perform matrix multiplication and measure time
    start_time = time.time()
    result = jnp.matmul(matrix_a, matrix_b)
    end_time = time.time()
    # Calculate TFLOPS
    elapsed_time = end_time - start_time
    num_operations = 2 * matrix_size**3
    tflops = num_operations / (elapsed_time * 1e12)
    return tflops
def measure_device_tflops(num_runs: int) -> float:
    total_tflops = 0.0
    first_run = True
    for _ in range(num_runs):
        total_tflops += tflops_run()
        # Not counting the first run
        if first_run:
            first_run = False
            total_tflops = 0.0
    average_tflops = total_tflops / num_runs
    return average_tflops
average_tflops = measure_device_tflops(num_runs)
print(f"Average Device TFLOPS: {average_tflops}")
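A note on the measurement itself: JAX dispatches work asynchronously, so stopping the clock right after `jnp.matmul` returns can capture mostly dispatch time rather than the computation. The script above also discards the first run but still divides by `num_runs`, and it reuses the same key for both matrices. Below is a minimal sketch of a timing harness that waits for results and averages only the measured runs. The function names and the small 1024 matrix size are illustrative assumptions, not part of the original test.

```python
import time
from typing import Callable


def time_call(fn: Callable[[], object]) -> float:
    """Time one call, waiting for asynchronous results before stopping the clock."""
    start = time.perf_counter()
    result = fn()
    # JAX arrays expose block_until_ready(); plain Python results do not.
    wait = getattr(result, "block_until_ready", None)
    if callable(wait):
        wait()
    return time.perf_counter() - start


def average_tflops(
    fn: Callable[[], object], flops_per_call: float, num_runs: int, warmup: int = 1
) -> float:
    """Average TFLOPS over num_runs timed calls, after untimed warm-up calls."""
    for _ in range(warmup):
        time_call(fn)  # discard warm-up runs (compilation, first allocations)
    timings = [time_call(fn) for _ in range(num_runs)]
    mean_seconds = sum(timings) / len(timings)
    return flops_per_call / (mean_seconds * 1e12)


if __name__ == "__main__":
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        print("jax is not installed; skipping the matmul example")
    else:
        n = 1024  # small illustrative size; raise it for a real measurement
        key_a, key_b = jax.random.split(jax.random.PRNGKey(0))  # distinct keys
        a = jax.random.normal(key_a, (n, n), dtype=jnp.bfloat16)
        b = jax.random.normal(key_b, (n, n), dtype=jnp.bfloat16)
        rate = average_tflops(lambda: jnp.matmul(a, b), flops_per_call=2 * n**3, num_runs=5)
        print(f"Average device TFLOPS: {rate:.3f}")
```

Generating the inputs once outside the timed call also avoids re-allocating two large matrices on every run, which may matter when memory is tight.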
Finally, for the performance disparity between `bfloat16` and `float16`, I did a new build this morning and it seems fine now??
EDIT: I haven't taken a look yet at the issues I am having running Jax test suite.
Hi, I encountered another weird bug where importing either `seaborn` or `matplotlib` before `jax` leads to a core dump with the message `Aborted (core dumped)`, but importing both after `jax` works without issues. I tested the same scenario on Colab, and it works without any issues in either case there. It's worth noting that Colab is using an older version of Jax, so the issue might be related to Jax itself. Should I proceed with reporting this bug to the Jax upstream?
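Since a crash like this takes the whole interpreter down, one way to compare import orders side by side is to run each order in a fresh subprocess and look at the exit code. This is a minimal sketch with a hypothetical helper name. On POSIX systems a negative return code means the process was killed by that signal number, so an abort (SIGABRT) would show up as -6.

```python
import subprocess
import sys
from typing import List


def run_import_order(modules: List[str]) -> int:
    """Import the modules in the given order in a fresh interpreter; return its exit code."""
    code = "; ".join(f"import {name}" for name in modules)
    result = subprocess.run(
        [sys.executable, "-c", code], capture_output=True, text=True, check=False
    )
    return result.returncode


if __name__ == "__main__":
    orders = (
        ["matplotlib", "jax"],
        ["jax", "matplotlib"],
        ["seaborn", "jax"],
        ["jax", "seaborn"],
    )
    for order in orders:
        # 0 = imports succeeded, 1 = Python exception (e.g. module missing),
        # negative = killed by a signal.
        print(" -> ".join(order), "exit code:", run_import_order(order))
```

The same table of exit codes from a second machine or backend would make it easier to tell whether the crash belongs to this plugin or to Jax upstream.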
Has Flax been tested? I am encountering some issues while testing it. For instance, just creating a small model:
import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
from typing import Any
class SimpleModel(nn.Module):
    num_hidden : int
    num_outputs : int
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.num_hidden)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.num_outputs)(x)
        return x
model = SimpleModel(num_hidden=8, num_outputs=1)
print(model)
I get the following error, while it works fine in Colab:
Traceback (most recent call last):
File "/home/fred/projects/jax_test/flax_mre.py", line 3, in <module>
import flax
File "/home/fred/miniconda3/envs/itex_build_2/lib/python3.10/site-packages/flax/__init__.py", line 22, in <module>
from . import core
File "/home/fred/miniconda3/envs/itex_build_2/lib/python3.10/site-packages/flax/core/__init__.py", line 16, in <module>
from .frozen_dict import (
File "/home/fred/miniconda3/envs/itex_build_2/lib/python3.10/site-packages/flax/core/frozen_dict.py", line 50, in <module>
@jax.tree_util.register_pytree_with_keys_class
AttributeError: module 'jax.tree_util' has no attribute 'register_pytree_with_keys_class'. Did you mean: 'register_pytree_node_class'?
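The traceback says the installed Flax expects an attribute of `jax.tree_util` that the installed Jax does not provide. A small check like the sketch below can confirm this kind of gap without importing Flax. The helper name `missing_attributes` is hypothetical, and the demonstration uses the standard library so it runs anywhere.

```python
import importlib
from typing import Iterable, List


def missing_attributes(module_name: str, names: Iterable[str]) -> List[str]:
    """Return the names from `names` that the imported module does not define."""
    module = importlib.import_module(module_name)
    return [name for name in names if not hasattr(module, name)]


# Demonstration with the standard library:
print(missing_attributes("math", ["sqrt", "not_a_real_function"]))
# prints ['not_a_real_function']

# In the environment from this thread, the equivalent check would be:
#   missing_attributes("jax.tree_util", ["register_pytree_with_keys_class"])
```

A non-empty result for the Jax check would point to a Flax release that is newer than the pinned Jax supports.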
Also, to avoid dependency conflicts, I found that it is important to include the Flax installation in the same pip command as Jax. For instance, I used the following command to install the required packages: `pip install tensorflow==2.12.0 jax==0.4.4 jaxlib==0.4.4 flax optax`.
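After an install like this, it can help to confirm that pip did not quietly move one of the pinned packages. The following is a minimal sketch using `importlib.metadata`; the pins are taken from the command above, the helper names are hypothetical, and a locally built `jaxlib` may report a version string with a suffix that this exact comparison would flag.

```python
from importlib import metadata
from typing import Dict, Optional, Tuple

# Pins copied from the pip command in this comment.
PINS = {"tensorflow": "2.12.0", "jax": "0.4.4", "jaxlib": "0.4.4"}


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def pin_mismatches(pins: Dict[str, str]) -> Dict[str, Tuple[str, Optional[str]]]:
    """Map package -> (expected, installed) for every package that deviates from its pin."""
    mismatches = {}
    for package, expected in pins.items():
        found = installed_version(package)
        if found != expected:
            mismatches[package] = (expected, found)
    return mismatches


if __name__ == "__main__":
    problems = pin_mismatches(PINS)
    for package, (expected, found) in problems.items():
        print(f"{package}: expected {expected}, found {found}")
    if not problems:
        print("All pinned packages match.")
```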
EDIT: Should I just start opening new issues for every bug I find?
I solved the issue I was having with the Jax test suite by executing `git checkout jaxlib-v0.4.4`.
The tests can't be run in parallel without crashing from Linux killing the process due to OOM errors, even when it is set to only 2 workers on a system that has 64GB of RAM. Also, it kinda messes with the system and it needed a reboot after every try.
I was able to start a test run with a single worker (took around 1h) after disabling `tests/api_test.py::AutodidaxTest::test_autodidax_smoketest` and got only 27 failed tests. Most are in `tests/array_interoperability_test.py`, 3 are in `tests/checkify_test.py`, and they are related to unknown-backend or unimplemented errors.
Overall not too bad!
The script `./tools/env_check.sh` can't run if we follow your installation instructions here, since we are not building Tensorflow, only Jax. Can I run a second Bazel build command to build Tensorflow without messing with Jax?
@fredlarochelle Thank you for your feedback. Yes, judging from the output of the PyTorch collect_env.py, the GPU driver and oneAPI tools are OK.
Your Arc A770 has 16 GB of memory, so it is a problem that "the OOM error randomly happens with error saying it failed trying to allocate 2.92 GiB".
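For reference, the size of the failing allocation can be compared with the size of one 28000 x 28000 matrix at different element widths. This is a back-of-the-envelope sketch with a hypothetical helper name; which buffer the 2.92 GiB request actually corresponds to is not established in this thread.

```python
GIB = 1024 ** 3


def square_matrix_gib(size: int, bytes_per_element: int) -> float:
    """Memory footprint in GiB of a dense size x size matrix."""
    return size * size * bytes_per_element / GIB


for label, width in (("bfloat16", 2), ("float32", 4)):
    print(f"{label}: {square_matrix_gib(28000, width):.2f} GiB")
# prints:
# bfloat16: 1.46 GiB
# float32: 2.92 GiB
```

The reported 2.92 GiB matches a matrix of 4-byte elements rather than the 2-byte `bfloat16` inputs. Either way, a single buffer of that size is far below the 16 GB on the card.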
I tried your matmul code on a Flex 170 (I have no Arc A770 available right now) and did not hit the OOM issue. We installed Ubuntu 22.04 Server. Judging from your kernel version, `Linux-5.19.0-46-generic-x86_64-with-glibc2.35`, yours should be the desktop edition.
In the output, there is:
36 2023-06-30 16:07:19.693569: I itex/core/devices/bfc_allocator.cc:29] Set memory limit to 15386382336 Bytes
Could you check the memory limit in the output of the quick matmul TFLOPS test you wrote (with matrix_size set to 28000)?
There are several memory allocations, but the log reports 0 deallocated temporaries. I am not sure Python reclaims the memory in time. If the memory is not reclaimed in time, out-of-memory errors could happen randomly.
37642 2023-06-30 16:07:21.476930: I itex/core/compiler/xla/stream_executor/stream_executor_pimpl.cc:291] Called StreamExecutor::Allocate(size=3148681344, memory_space=0) returns 0xffff81ae00000000
37643 2023-06-30 16:07:22.004857: I itex/core/compiler/xla/stream_executor/stream_executor_pimpl.cc:217] Called StreamExecutor::Deallocate(mem=0xffff81ae00000000) mem->size()=3148681344
37644 2023-06-30 16:07:22.004901: I itex/core/compiler/xla/stream_executor/temporary_memory_manager.cc:69] deallocated 0 finalized temporaries
37645 2023-06-30 16:07:22.557165: I itex/core/compiler/xla/service/stream_pool.cc:61] [stream=0x5621c7d2e8b0,impl=0x5621c7d10a10] StreamPool returning ok stream
37646 2023-06-30 16:07:22.557207: I itex/core/compiler/xla/service/gpu/gpu_executable.cc:477] GpuExecutable::ExecuteAsyncOnStreamImpl(jit_matmul) time: 1.52 s (cumulative: 1.89 s, max: 1.52 s, #called: 10)
37647 2023-06-30 16:07:22.557222: I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1989] Replica 0 partition 0 completed; ok=1
37648 2023-06-30 16:07:22.557354: I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2268] Replicated execution complete.
37649 2023-06-30 16:07:22.557542: I itex/core/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1208] PjRtStreamExecutorBuffer::Delete
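To check whether allocations and deallocations balance over a longer log, the `StreamExecutor::Allocate` and `StreamExecutor::Deallocate` lines can be summed. The following is a minimal sketch: the line format is assumed from the excerpt above, the function name is hypothetical, and the sample covers only the two relevant lines from the excerpt.

```python
import re
from typing import Iterable, Tuple

# Patterns assumed from the log excerpt above.
ALLOCATE = re.compile(r"StreamExecutor::Allocate\(size=(\d+),")
DEALLOCATE = re.compile(r"StreamExecutor::Deallocate\(.*?\) mem->size\(\)=(\d+)")


def allocation_totals(lines: Iterable[str]) -> Tuple[int, int]:
    """Return (bytes allocated, bytes deallocated) summed over the log lines."""
    allocated = deallocated = 0
    for line in lines:
        match = ALLOCATE.search(line)
        if match:
            allocated += int(match.group(1))
            continue
        match = DEALLOCATE.search(line)
        if match:
            deallocated += int(match.group(1))
    return allocated, deallocated


SAMPLE = [
    "Called StreamExecutor::Allocate(size=3148681344, memory_space=0) returns 0xffff81ae00000000",
    "Called StreamExecutor::Deallocate(mem=0xffff81ae00000000) mem->size()=3148681344",
]

allocated, deallocated = allocation_totals(SAMPLE)
print(f"allocated={allocated} deallocated={deallocated} outstanding={allocated - deallocated}")
# prints: allocated=3148681344 deallocated=3148681344 outstanding=0
```

In this excerpt the 3148681344-byte allocation is released again at the same address, so a growing "outstanding" value over the full log would be the signal to look for.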
So far, Flex GPUs are officially supported and fully verified, but Arc GPUs are experimental.
Here is my test environment: I installed tensorflow 2.12 and intel-extension-for-tensorflow[gpu] 1.2.0, and built libitex_xla_extension.so only.
I appear to have discovered a potential solution to the OOM errors encountered during the matmul test. Thus far, I haven't been able to reproduce the error using the following approach: instead of using the `ats-m150` device type for AOT, I opted for `acm-g10` in a new build. However, further testing is required to determine whether this change introduces any other issues.
After this change, I conducted a brief retest of the previously identified problems, and unfortunately, they still persist. Regarding the Jax test suite, it now appears to be able to run in parallel. However, numerous tests that pass with a single worker are failing.
In terms of the operating system, I am currently running Ubuntu Server. To clarify, should I switch to Ubuntu Desktop or stay with Ubuntu Server? Technically, the kernel remains largely the same between the two versions, with the main distinction being the absence of a desktop environment in the server edition.
@fredlarochelle It is good to hear that the OOM issue was resolved!
Could you summarize which issues are resolved and which are still open?
Nothing runs in a Jupyter notebook
Status: open or closed? It seems to be a configuration/settings issue.
Out of memory
Status: resolved. AOT build with the acm-g10 device type (instead of ats-m150).
Importing seaborn or matplotlib before jax gives Aborted (core dumped); importing seaborn or matplotlib after jax is OK. The same scenario works without any issue on Colab (which uses an older version of jax).
Status: open or resolved?
Flax issue
File "/home/fred/miniconda3/envs/itex_build_2/lib/python3.10/site-packages/flax/core/frozen_dict.py", line 50, in
From the message, it seems the class has changed. It should be a version mismatch issue.
Status: open or closed?
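Here is the summary, and I have added a couple:
OPEN - Nothing runs in a Jupyter notebook (it doesn't work with both VSCode and with Jupyter on my system/setup)
RESOLVED - Out of memory -> acm-g10 instead of ats-m150
OPEN - import seaborn or matplotlib before Jax doesn't work
OPEN - Flax issues (I will take a deeper look at the dependencies/version mismatch)
OPEN - Update the `./tools/env_check.sh` script to run without Tensorflow
OPEN - Running Jax test suite in parallel (`pytest -n auto tests`) doesn't work.
OPEN - 27 tests failing in the Jax test suite when run with a single worker (someone probably should try the test suite to see if the issue is with my system/setup)
OPEN - Updating Jax and Jaxlib to a newer release? We are on `v0.4.4` that dates back from February, the latest release is `v0.4.13`. That might solve some of the issues, but bring new ones too.
Other than that, from my still limited testing, it's impressive, performance seems pretty good.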
Hi @fredlarochelle,
Thanks for your feedback.
This project focuses on Intel Extension for TensorFlow. Right now there are still lots of TensorFlow features to enable on Flex/Arc GPUs.
We will try our best to help you resolve issues with the platform/hardware. For Jax test suite coverage and non-hardware-specific system setup/configuration, we cannot cover them now because we are short on resources.
Here are my suggestions:
OPEN - Nothing runs in a Jupyter notebook (it doesn't work with both VSCode and with Jupyter on my system/setup)
Could you check whether other Python applications work in Jupyter notebooks on your system?
RESOLVED - Out of memory -> acm-g10 instead of ats-m150
OPEN - import seaborn or matplotlib before Jax doesn't work
Can this be reproduced on an Nvidia platform? I just want to make sure whether or not it is an Intel Extension for TensorFlow issue.
OPEN - Flax issues (I will take a deeper look at the dependencies/version mismatch)
It looks like a data structure changed and some members are not available. Version mismatch.
OPEN - Update the ./tools/env_check.sh script to run without Tensorflow
Intel Extension for TensorFlow (ITEX) depends on TensorFlow. TensorFlow is a must. Won't fix.
OPEN - Running Jax test suite in parallel (pytest -n auto tests) doesn't work.
Keeping it at low priority for us for now.
OPEN - 27 tests failing in the Jax test suite when run with a single worker (someone probably should try the test suite to see if the issue is with my system/setup)
Keeping it at low priority unless a specific ITEX issue is identified.
OPEN - Updating Jax and Jaxlib to a newer release? We are on v0.4.4 that dates back from February, the latest release is v0.4.13. That might solve some of the issues, but bring new ones too.
@feng-intel, do you have a plan to support a newer jax release?
Other than that, from my still limited testing, it's impressive, performance seems pretty good.
@fredlarochelle
For the software configuration and Jax test suite coverage, we currently have no resources to support them.
If you have a specific ITEX issue, for example a Jax test case that succeeds on other platforms but fails on ITEX, please start a new issue; we are happy to support it.
As for XLA support itself, it is supported now. Can I close this issue?
Let's close it.
As part of this extension is any work being done on adding support to XLA for Intel GPUs? With support for XLA, Intel GPUs could work with the whole Jax ecosystem.