jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

triton_autotuner: Rounding modifier required for instruction 'cvt' #15900

Open KeremTurgutlu opened 1 year ago

KeremTurgutlu commented 1 year ago

Description

I'm getting the following error when trying to run code on an A100 80GB on Google Cloud using the Debian Deep Learning image (c0-deeplearning-common-cu113-v20230501-debian-10). The same code is tested and works on TPU (using the t5x library). I don't know whether this error is related to my setup, but these are the steps I took after creating the instance and before running the code:

1) Created a new conda environment with py3.9

2) Installed the latest JAX with CUDA support: pip install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html, which is 4.8.0 as of writing.

3) Cloned the t5x library and installed an editable local version: -e git+https://github.com/google-research/t5x.git@2b010160e7fe8a4505a6d1032a7b737a633636e5#egg=t5x.

4) Installed an extra dependency: pip install t5.

5) Upgraded the cuDNN library to 8.6.0, since JAX complained that it requires at least that version, by manually downloading cudnn-linux-x86_64-8.6.0.163_cuda11-archive.tar.xz and then running the following:

$ sudo cp cudnn-*-archive/include/cudnn*.h /usr/local/cuda/include 
$ sudo cp -P cudnn-*-archive/lib/libcudnn* /usr/local/cuda/lib64 
$ sudo chmod a+r /usr/local/cuda/include/cudnn*.h /usr/local/cuda/lib64/libcudnn*

6) Verified that the GPU is usable by JAX:

>>> import jax
>>> jax.device_put(jax.numpy.ones(1), device=jax.devices('gpu')[0]).device()
StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)
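Because step 5 copies the cuDNN headers by hand, it can be worth confirming which version they actually declare. The sketch below is a hypothetical helper (parse_cudnn_version is not part of JAX or cuDNN). It assumes cuDNN 8.x, where the version macros CUDNN_MAJOR, CUDNN_MINOR and CUDNN_PATCHLEVEL are defined in cudnn_version.h, and the /usr/local/cuda/include destination used above. It only inspects the header, not the shared library that is actually loaded at runtime.

```python
import re
from pathlib import Path
from typing import Tuple


def parse_cudnn_version(header_text: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch) from the #define lines of cudnn_version.h."""
    values = []
    for name in ("CUDNN_MAJOR", "CUDNN_MINOR", "CUDNN_PATCHLEVEL"):
        match = re.search(rf"#define\s+{name}\s+(\d+)", header_text)
        if match is None:
            raise ValueError(f"{name} not found in header")
        values.append(int(match.group(1)))
    return values[0], values[1], values[2]


if __name__ == "__main__":
    # Assumption: headers were copied here, as in step 5 above.
    header = Path("/usr/local/cuda/include/cudnn_version.h")
    if header.exists():
        version = parse_cudnn_version(header.read_text())
        print("cuDNN headers declare", ".".join(map(str, version)))
        if version < (8, 6, 0):
            print("older than the 8.6.0 that JAX asked for")
    else:
        print(f"{header} not found")
```

Tuples compare element by element, so the `version < (8, 6, 0)` check orders versions correctly without any string handling.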

This is the error I get when running a t5x pretraining script with train.py:

Traceback (most recent call last):
  File "/home/keremturgutlu/t5x/t5x/train.py", line 835, in <module>
    config_utils.run(main)
  File "/home/keremturgutlu/t5x/t5x/config_utils.py", line 214, in run
    gin_utils.run(main)
  File "/home/keremturgutlu/t5x/t5x/gin_utils.py", line 129, in run
    app.run(
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/keremturgutlu/t5x/t5x/train.py", line 788, in main
    _main(argv)
  File "/home/keremturgutlu/t5x/t5x/train.py", line 830, in _main
    train_using_gin()
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/home/keremturgutlu/t5x/t5x/train.py", line 614, in train
    trainer.compile_train(dummy_batch)
  File "/home/keremturgutlu/t5x/t5x/trainer.py", line 538, in compile_train
    self._compiled_train_step = self._partitioner.compile(
  File "/home/keremturgutlu/t5x/t5x/partitioning.py", line 805, in compile
    return partitioned_fn.lower(*args).compile()
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/jax/_src/stages.py", line 600, in compile
    self._lowering.compile(**kw),
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2836, in compile
    self._executable = UnloadedMeshExecutable.from_hlo(
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 3048, in from_hlo
    xla_executable = dispatch.compile_or_get_cached(
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 526, in compile_or_get_cached
    return backend_compile(backend, serialized_computation, compile_options,
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: ptxas exited with non-zero error code 65280, output: ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 234; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 238; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 242; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 246; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 250; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 254; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 258; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 262; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 266; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 270; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 274; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 278; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 282; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 286; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 290; error   : Rounding modifier required for instruction 'cvt'
ptxas /var/tmp/tempfile-gpu-b325a549-19214-5fb10904b7b5a, line 294; error   : Rounding modifier required for instruction 'cvt'
ptxas fatal   : Ptx assembly aborted due to errors

What jax/jaxlib version are you using?

jax==0.4.7 jaxlib==0.4.7+cuda11.cudnn86

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.9.16 | packaged by conda-forge | (main, Feb 1 2023, 21:39:03) [GCC 11.3.0] on linux

NVIDIA GPU info

Sun May  7 01:53:18 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:05.0 Off |                    0 |
| N/A   34C    P0    82W / 400W |      0MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Mon_May__3_19:15:13_PDT_2021
Cuda compilation tools, release 11.3, V11.3.109
Build cuda_11.3.r11.3/compiler.29920130_0

cheshire commented 1 year ago

@KeremTurgutlu How do we know that it originates in the triton_autotuner?

cheshire commented 1 year ago

@hawkinsp Is it hard to propagate the C++ stack trace along with the error?

KeremTurgutlu commented 1 year ago

> @KeremTurgutlu How do we know that it originates in the triton_autotuner?

I am not 100% sure that's the root cause, but I should probably have pasted this as well:

[triton_autotuner.cc:271] failure: internal: ptxas exited with non-zero error code 65280, output: ptxas /var/tmp/tempfile-gpu-3b3b9d27-29193-5fb10408aefa4, line 234; error : rounding modifier required for instruction 'cvt'

I was able to run the code successfully after installing the NVIDIA driver, CUDA (12.1), cuDNN and JAX from scratch:

1) Launched an A100 instance in Google Cloud with the base Ubuntu 18.04 image.

2) Installed the latest NVIDIA driver with CUDA 12.1.

3) Installed Miniconda and created a conda env.

4) Installed JAX and cuDNN:


# CUDA 12 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

5) Installed t5x from source, and installed t5.

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-80GB           On | 00000000:00:05.0 Off |                    0 |
| N/A   40C    P0               60W / 400W|  73901MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A      9080      C   python3                                   73890MiB |
+---------------------------------------------------------------------------------------+

nvcc not installed.

cheshire commented 1 year ago

Oh, this is CUDA 12; that would probably explain it. We have had some bugs filed on CUDA 12 before.

hawkinsp commented 1 year ago

I think this is actually a case of too old a CUDA installation, not the other way around. The image is named c0-deeplearning-common-cu113-v20230501-debian-10: note "cu113".

JAX is built for CUDA 11.8 (or CUDA 12), and if I recall correctly, Ampere GPU support wasn't added until well after 11.3.

Can you update to CUDA 11.8 or newer?

Note that nvidia-smi reports the CUDA version supported by the driver, not the version of the installed CUDA libraries. You need both to be sufficiently new.
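Since nvidia-smi only shows the driver side, the toolkit side has to be checked separately. A minimal sketch, assuming the toolkit's ptxas (or nvcc) prints a "release X.Y" line, as in the version outputs quoted in this thread, and taking 11.8 as the minimum from the comment above. The helper names (parse_release, toolkit_release, MIN_CUDA) are hypothetical.

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple

# Assumption: 11.8 is the minimum toolkit release, per the comment above.
MIN_CUDA = (11, 8)


def parse_release(version_output: str) -> Tuple[int, int]:
    """Pull (major, minor) out of the 'release X.Y' line of `--version` output."""
    match = re.search(r"release (\d+)\.(\d+)", version_output)
    if match is None:
        raise ValueError("no 'release X.Y' line found")
    return int(match.group(1)), int(match.group(2))


def toolkit_release(tool: str = "ptxas") -> Optional[Tuple[int, int]]:
    """Run `<tool> --version` for the first copy on PATH; None if there is none."""
    path = shutil.which(tool)
    if path is None:
        return None
    result = subprocess.run([path, "--version"], capture_output=True, text=True, check=True)
    return parse_release(result.stdout)


if __name__ == "__main__":
    release = toolkit_release()
    if release is None:
        print("ptxas not found on PATH")
    else:
        status = "OK" if release >= MIN_CUDA else "too old"
        print(f"ptxas release {release[0]}.{release[1]}: {status}")
```

Applied to the outputs quoted in this thread, parse_release gives (11, 3) for the nvcc on the Deep Learning image and (11, 6) for the Runpod ptxas; both fall below (11, 8).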

KeremTurgutlu commented 1 year ago

Sorry if it was not clear, but what I wanted to mention was that the issue was fixed when I installed CUDA 12 from scratch instead of using the Google Cloud image.

KeremTurgutlu commented 1 year ago

I recently got this error (it might be related to the 0.4.9 release; I'm looking into it):

(myenv) keremturgutlu@gpu:~$ python -c "import jax; print(jax.device_put(jax.numpy.ones(1), device=jax.devices('gpu')[0]).device())"
2023-05-10 07:12:19.701059: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:429] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 2092, in ones
    return lax.full(shape, 1, _jnp_dtype(dtype))
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 1190, in full
    return broadcast(fill_value, shape)
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 756, in broadcast
    return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 784, in broadcast_in_dim
    return broadcast_in_dim_p.bind(
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/core.py", line 360, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/core.py", line 363, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/core.py", line 817, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 117, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/util.py", line 253, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/util.py", line 246, in cached
    return f(*args, **kwargs)
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 208, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), prim.name,
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 254, in _xla_callable_uncached
    return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2816, in compile
    self._executable = UnloadedMeshExecutable.from_hlo(
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 3028, in from_hlo
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 526, in compile_or_get_cached
    return backend_compile(backend, serialized_computation, compile_options,
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/keremturgutlu/miniconda3/envs/myenv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

Edit: I tried again on a freshly created instance and wasn't able to reproduce the error.

euclaise commented 1 year ago

I'm getting a similar error when trying to run a custom model on an A6000 after pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html on top of Runpod's PyTorch or TensorFlow images. I tried both cuda11 and cuda12 and got the same issue.

https://pastebin.com/raw/MUsYZje8

Update: it seems to happen only with bfloat16; it works fine with float32.

hawkinsp commented 1 year ago

@euclaise Do you have another copy of ptxas installed? Is there one in your PATH? My strong suspicion is still "we are finding an ancient ptxas".
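shutil.which() (like the shell) reports only the first match, so an older ptxas earlier in PATH can silently shadow a newer one. A small sketch that lists every ptxas on PATH in search order. find_all_on_path is a hypothetical helper, and it only inspects PATH, which may not be the only place XLA looks for ptxas.

```python
import os
from typing import List, Optional


def find_all_on_path(name: str, path: Optional[str] = None) -> List[str]:
    """Return every executable file called `name` on PATH, in search order."""
    search = path if path is not None else os.environ.get("PATH", "")
    hits: List[str] = []
    for directory in search.split(os.pathsep):
        if not directory:
            continue  # skip empty PATH entries
        candidate = os.path.join(directory, name)
        # Keep only regular files with execute permission, without duplicates.
        if os.path.isfile(candidate) and os.access(candidate, os.X_OK) and candidate not in hits:
            hits.append(candidate)
    return hits


if __name__ == "__main__":
    hits = find_all_on_path("ptxas")
    if not hits:
        print("no ptxas on PATH")
    for i, hit in enumerate(hits):
        print(("first on PATH: " if i == 0 else "shadowed:      ") + hit)
```

Running `ptxas --version` against each listed path then shows whether an old toolkit is the one found first.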

euclaise commented 1 year ago

@hawkinsp I don't, but it should be whichever one ships with this image: https://hub.docker.com/r/runpod/pytorch/

euclaise commented 1 year ago

@hawkinsp

ptxas: NVIDIA (R) Ptx optimizing assembler
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Tue_Mar__8_18:17:32_PST_2022
Cuda compilation tools, release 11.6, V11.6.124
Build cuda_11.6.r11.6/compiler.31057947_0

After some testing, it appears to be caused by my accidentally mixing bfloat16 values with float32 ones. It seems a check for that is missing somewhere before assembly.
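A minimal JAX sketch of the kind of dtype mixing described above, assuming the model is meant to run entirely in bfloat16. It does not reproduce the ptxas failure. It only shows that JAX's type promotion silently upcasts bfloat16 combined with float32 to float32, and one way (an explicit astype) to keep a computation in a single dtype. scaled is a hypothetical function, not part of the poster's model.

```python
import jax
import jax.numpy as jnp

x_bf16 = jnp.ones((4,), dtype=jnp.bfloat16)
y_f32 = jnp.ones((4,), dtype=jnp.float32)

# Mixing dtypes: JAX's promotion rules turn bfloat16 * float32 into float32,
# so part of the computation silently leaves bfloat16.
mixed = x_bf16 * y_f32
print(mixed.dtype)  # float32

# Casting explicitly keeps the whole expression in bfloat16.
consistent = x_bf16 * y_f32.astype(jnp.bfloat16)
print(consistent.dtype)  # bfloat16


@jax.jit
def scaled(x, scale):
    # Hypothetical helper: cast `scale` to x's dtype before combining, so the
    # traced (and compiled) computation stays in a single precision.
    return x * scale.astype(x.dtype)


print(scaled(x_bf16, y_f32).dtype)  # bfloat16
```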