Open KeremTurgutlu opened 1 year ago
@KeremTurgutlu How do we know that it originates in the triton_autotuner?
@hawkinsp Is it hard to propagate C++ stack trace along with the error?
I am not 100% sure that's the root cause, but I should've probably pasted this as well:
[triton_autotuner.cc:271] failure: internal: ptxas exited with non-zero error code 65280, output: ptxas /var/tmp/tempfile-gpu-3b3b9d27-29193-5fb10408aefa4, line 234; error : rounding modifier required for instruction 'cvt'
I was able to successfully run the code with a from-scratch NVIDIA driver, CUDA (12.1), cuDNN, and JAX installation:
1) Launched an A100 in Google Cloud with a base Ubuntu 18.04 image.
2) Installed the latest NVIDIA driver with CUDA 12.1.
3) Installed Miniconda and created a conda env.
4) Installed jax and cudnn:
# CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
5) Installed t5x from source and installed t5.
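For reference, a minimal check that jax actually picks up the GPU after these steps (a sketch, not the exact command from this setup):
# Should list the A100 as a CUDA device
python -c "import jax; print(jax.devices())"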
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02 Driver Version: 530.30.02 CUDA Version: 12.1 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA A100-SXM4-80GB On | 00000000:00:05.0 Off | 0 |
| N/A 40C P0 60W / 400W| 73901MiB / 81920MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 9080 C python3 73890MiB |
+---------------------------------------------------------------------------------------+
nvcc not installed.
Oh, this is CUDA 12; that would probably explain it. We had some bugs filed on CUDA 12 before.
I think this is actually a case of too old a CUDA installation, not the other way around. The image is named c0-deeplearning-common-cu113-v20230501-debian-10: note "cu113". JAX is built for CUDA 11.8 (or CUDA 12), and if I recall correctly Ampere GPU support wasn't added until long after 11.3. Can you update to CUDA 11.8 or newer?
Note that nvidia-smi reports the CUDA version of the driver, not the installed libraries. You need both to be sufficiently new.
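A quick way to see both sides (a sketch; jax.print_environment_info() is available in recent jax releases):
# Driver-side CUDA version, as reported by the driver
nvidia-smi --query-gpu=driver_version --format=csv,noheader
# Library-side versions that jax actually sees in this environment
python -c "import jax; jax.print_environment_info()"
# Toolkit assembler version, if one is on PATH
ptxas --version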
Sorry if it was not clear, but what I wanted to mention was that the issue was fixed when I installed CUDA 12 from scratch instead of using the Google Cloud image.
Recently got this error (might be related to the 0.4.9 release; looking into it):
(myenv) keremturgutlu@gpu:~$ python -c "import jax; print(jax.device_put(jax.numpy.ones(1), device=jax.devices('gpu')[0]).device())"
2023-05-10 07:12:19.701059: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:429] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 2092, in ones
return lax.full(shape, 1, _jnp_dtype(dtype))
File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 1190, in full
return broadcast(fill_value, shape)
File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 756, in broadcast
return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 784, in broadcast_in_dim
return broadcast_in_dim_p.bind(
File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/core.py", line 360, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/core.py", line 363, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/core.py", line 817, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 117, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/util.py", line 253, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/util.py", line 246, in cached
return f(*args, **kwargs)
File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 208, in xla_primitive_callable
compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), prim.name,
File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 254, in _xla_callable_uncached
return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2816, in compile
self._executable = UnloadedMeshExecutable.from_hlo(
File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 3028, in from_hlo
xla_executable = dispatch.compile_or_get_cached(
File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 526, in compile_or_get_cached
return backend_compile(backend, serialized_computation, compile_options,
File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
Edit: Tried again on a newly created instance, and I wasn't able to reproduce the error.
Getting a similar error when trying to run a custom model on an A6000 with pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html on top of Runpod's PyTorch or TensorFlow images. Tried both CUDA 11 and CUDA 12; same issue.
https://pastebin.com/raw/MUsYZje8
Update: It seems to only happen with bfloat16; it works fine with float32.
@euclaise Do you have another copy of ptxas installed? Is there one in your PATH? My strong suspicion is still "we are finding an ancient ptxas".
@hawkinsp I don't, but it should be whatever is used here https://hub.docker.com/r/runpod/pytorch/
@hawkinsp
ptxas: NVIDIA (R) Ptx optimizing assembler
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Tue_Mar__8_18:17:32_PST_2022
Cuda compilation tools, release 11.6, V11.6.124
Build cuda_11.6.r11.6/compiler.31057947_0
After some testing, it appears to be caused by me accidentally mixing bfloat16 values with float32 ones. It seems a check for that is missing somewhere prior to assembly.
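For what it's worth, a minimal way to exercise that kind of mixed-dtype call (a sketch; I can't confirm this exact op reproduces the ptxas failure):
# bfloat16 @ float32 operands; jax's type promotion produces a float32 result
python -c "import jax, jax.numpy as jnp; a = jnp.ones((16, 16), jnp.bfloat16); b = jnp.ones((16, 16), jnp.float32); print(jax.jit(jnp.matmul)(a, b).dtype)"
An explicit .astype(jnp.bfloat16) on one operand keeps the whole computation in a single dtype as a workaround.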
Description
Getting the following error when trying to run code on an A100 80GB Google Cloud Debian Deep Learning image (c0-deeplearning-common-cu113-v20230501-debian-10). This code is tested and works on TPU (using the t5x library). I don't know if this error is related to my setup, but these are the steps I took after creating the instance and before running the code:
1) Created a new conda environment with py3.9.
2) Installed the latest jax cuda release: pip install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html, which is 4.8.0 as of writing.
3) Cloned the t5x library and installed an editable local version: -e git+https://github.com/google-research/t5x.git@2b010160e7fe8a4505a6d1032a7b737a633636e5#egg=t5x.
4) Installed the extra dependency: pip install t5.
5) Upgraded the cuDNN library to 8.6.0, since jax complained it requires at least that version, by manually downloading cudnn-linux-x86_64-8.6.0.163_cuda11-archive.tar.xz and then running the install commands (a typical sequence is sketched after this list).
6) Verified the GPU is usable by jax.
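Regarding step 5, a typical sequence for a cuDNN tar-archive install looks like the following (a sketch; the /usr/local/cuda target path is an assumption and may differ on the Deep Learning image):
# Unpack the cuDNN 8.6.0 archive and copy headers/libraries into the CUDA tree
tar -xf cudnn-linux-x86_64-8.6.0.163_cuda11-archive.tar.xz
sudo cp cudnn-linux-x86_64-8.6.0.163_cuda11-archive/include/cudnn*.h /usr/local/cuda/include/
sudo cp -P cudnn-linux-x86_64-8.6.0.163_cuda11-archive/lib/libcudnn* /usr/local/cuda/lib64/
sudo ldconfig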
The following is the error I get when running a t5x pretraining script using train.py.
What jax/jaxlib version are you using?
jax==0.4.7 jaxlib==0.4.7+cuda11.cudnn86
Which accelerator(s) are you using?
GPU
Additional system info
Python 3.9.16 | packaged by conda-forge | (main, Feb 1 2023, 21:39:03) [GCC 11.3.0] on linux
NVIDIA GPU info