jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Attempt to perform FFT requiring batched cufft plan of size > 4 GB causes coredump #9591

Open erdmann opened 2 years ago

erdmann commented 2 years ago

Attempting to perform a somewhat large FFT results in a coredump rather than an OOM RuntimeError.

On my 32 GB V100 GPU, the following code illustrates the problem:

import jax
import jaxlib
from jax import jit, vmap
import jax.numpy as jnp
import numpy as onp

jax.config.update('jax_enable_x64', True)  # allows use of float64

print(f'{jax.devices()[0].device_kind = }')
print(f'{jax.__version__ = }, {jaxlib.__version__ = }')

s = jnp.ones(2**31 // onp.float64(0).itemsize) # make a 2 GB array
t = jnp.ones(s.shape[0]+1) # make a 2 GB + 8 B array

# Works fine
try:
    print(f'Attempting {s.shape = }, {s.nbytes = }, {s.dtype = }')
    jnp.fft.fft(s).block_until_ready()
    print('FFT on 2GB array succeeded')
except RuntimeError as e:
    print('2 GB Failed.')

# coredumps:
try:
    print(f'Attempting {t.shape = }, {t.nbytes = }, {t.dtype = }')
    jnp.fft.fft(t).block_until_ready()
    print('FFT on 2 GB + 8 B array succeeded')
except RuntimeError as e:
    print('FFT on 2 GB + 8 B array failed')

print('Finished all tests')  # never reached

This results in the following output for me:

jax.devices()[0].device_kind = 'Tesla V100-DGXS-32GB'
jax.__version__ = '0.3.0', jaxlib.__version__ = '0.3.0'
Attempting s.shape = (268435456,), s.nbytes = 2147483648, s.dtype = dtype('float64')
FFT on 2GB array succeeded
Attempting t.shape = (268435457,), t.nbytes = 2147483656, t.dtype = dtype('float64')
2022-02-16 07:43:08.806443: F external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_fft.cc:439] failed to initialize batched cufft plan with customized allocator: Failed to make cuFFT batched plan.
Aborted (core dumped)
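The `F` prefix on the log line marks a fatal log from native code. The process aborts before Python can raise anything, so the `try`/`except RuntimeError` blocks above never get a chance to run. One way to probe a size without losing the parent process is to run the FFT in a child interpreter and check its exit status. This is a minimal sketch. It assumes a POSIX system, where a child killed by a signal such as SIGABRT reports a negative return code, and a working `jax` install for the same Python. `fft_survives` is a hypothetical helper, not a JAX API.

```python
import subprocess
import sys
import textwrap


def fft_survives(n: int, timeout: float = 600.0) -> bool:
    """Run a length-n float64 FFT in a child Python process (hypothetical helper).

    Returns True if the child exits normally. Returns False if it exits with a
    non-zero status, or if it is killed by a signal (negative return code on POSIX).
    """
    script = textwrap.dedent(f"""
        import jax
        jax.config.update('jax_enable_x64', True)
        import jax.numpy as jnp
        jnp.fft.fft(jnp.ones({n})).block_until_ready()
    """)
    proc = subprocess.run([sys.executable, "-c", script],
                          capture_output=True, timeout=timeout)
    return proc.returncode == 0
```

On the machine in the report, `fft_survives(268435457)` would be expected to return False instead of taking the caller down with it. That expectation is not verified here.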

Also: even if this is fixed to result in a RuntimeError rather than a coredump, is there some way around this relatively small size limitation? I am using jnp.fft.rfft2 to perform large image convolutions and I often bump up against this.
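For reference, an `rfft2`-based convolution of a stack of images can be written so that each FFT call covers only one image. This is a minimal sketch. It assumes the images arrive as a `(batch, H, W)` stack and that transforming them one at a time keeps each cuFFT plan smaller. It does not help when a single image is already too large. `fft_convolve2d` and `convolve_batch_sequentially` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def fft_convolve2d(image, kernel):
    """Full 2-D linear convolution of `image` with `kernel` using rfft2/irfft2."""
    out_shape = (image.shape[0] + kernel.shape[0] - 1,
                 image.shape[1] + kernel.shape[1] - 1)
    # Zero-pad both operands to the full output size (via `s=`) so the FFT's
    # circular convolution equals the linear convolution.
    f_image = jnp.fft.rfft2(image, s=out_shape)
    f_kernel = jnp.fft.rfft2(kernel, s=out_shape)
    return jnp.fft.irfft2(f_image * f_kernel, s=out_shape)


@jax.jit
def convolve_batch_sequentially(images, kernel):
    """Convolve each (H, W) image of a (batch, H, W) stack with `kernel`.

    jax.lax.map runs the body once per leading-axis slice, sequentially, so
    each FFT covers a single image rather than the whole stack.
    """
    return jax.lax.map(lambda img: fft_convolve2d(img, kernel), images)
```

The design trades parallelism for a smaller per-call transform. `jax.lax.map` is built on a scan, so images are processed one after another instead of in one batched FFT.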

Thanks in advance!

hawkinsp commented 2 years ago

Thanks for the report, this looks to be an XLA issue. I filed an internal bug for the XLA folks (Google bug b/219941181).

xidulu commented 2 years ago

I met this issue as well!

I am able to "fix" this issue by adding:

import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] ='false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR']='platform'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

to the top of the file.
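These allocator variables only take effect if they are set before JAX creates its GPU backend. The simplest way to guarantee that is to set them before the first `import jax`. The sketch below enforces that ordering. The `WORKAROUND_ENV` dict, the `apply_workaround_env` helper and its `sys.modules` guard are hypothetical. `TF_FORCE_GPU_ALLOW_GROWTH` is copied from the comment above without verifying that JAX reads it. JAX's GPU memory documentation describes the `platform` allocator as very slow, so it suits debugging OOMs better than production runs.

```python
import os
import sys

# Settings from the workaround above. The XLA_PYTHON_CLIENT_* variables are
# documented JAX GPU memory options; TF_FORCE_GPU_ALLOW_GROWTH is copied as-is
# (whether JAX reads it is not verified here).
WORKAROUND_ENV = {
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",  # don't reserve GPU memory up front
    "XLA_PYTHON_CLIENT_ALLOCATOR": "platform",  # allocate/free on demand (slow)
    "TF_FORCE_GPU_ALLOW_GROWTH": "true",
}


def apply_workaround_env() -> None:
    """Set the variables; refuse if jax was already imported (a conservative guard)."""
    if "jax" in sys.modules:
        raise RuntimeError("set the XLA_* variables before importing jax")
    os.environ.update(WORKAROUND_ENV)


apply_workaround_env()
import jax  # noqa: E402  (deliberately imported after the environment is set)

print(jax.devices())
```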

xidulu commented 2 years ago

@hawkinsp Do you have any updates on this issue?

etienne-thuillier commented 2 years ago

Hi,

I am blocked by this exact issue using

The error is reproduced below.

I am not familiar with Google's ticketing system and could not find ticket b/219941181. Has this issue been resolved?

With thanks!


2022-07-06 22:48:01.859405: F external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_fft.cc:439] failed to initialize batched cufft plan with customized allocator: Allocating 5945425920 bytes exceeds the memory limit of 4294967296 bytes.

NB: 4294967296 bytes == 2**32 bytes == 4 GiB
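The same 2**32-byte boundary also matches the original reproduction. With `jax_enable_x64` on, `jnp.fft.fft` of a float64 array returns complex128 (16 bytes per element). So the 2 GB input `s` yields exactly 2**32 bytes of output, and `t` yields 16 bytes more. The arithmetic below only shows that the sizes line up with the limit. The work area cuFFT actually requests (5945425920 bytes in the log above) is larger than the output buffer alone, so treat this as a hint rather than a diagnosis. `fft_buffer_bytes` is a hypothetical helper.

```python
# Byte sizes of the buffers in the original reproduction, assuming
# jax_enable_x64 is on (float64 input, complex128 FFT output).
FLOAT64_BYTES = 8
COMPLEX128_BYTES = 16
LIMIT_BYTES = 2**32  # 4294967296, the limit quoted in the cuFFT error above


def fft_buffer_bytes(n: int) -> tuple[int, int]:
    """Return (input_bytes, output_bytes) for a length-n float64 -> complex128 FFT."""
    return n * FLOAT64_BYTES, n * COMPLEX128_BYTES


n_s = 2**31 // FLOAT64_BYTES  # length of `s` in the report
n_t = n_s + 1                 # length of `t` in the report

for name, n in (("s", n_s), ("t", n_t)):
    in_bytes, out_bytes = fft_buffer_bytes(n)
    print(f"{name}: n={n}, input={in_bytes} B, output={out_bytes} B, "
          f"output > 2**32: {out_bytes > LIMIT_BYTES}")
```

This prints an output size of 4294967296 B (not above the limit) for `s` and 4294967312 B (above it) for `t`. The input sizes agree with the `nbytes` values in the report.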

hawkinsp commented 2 years ago

There has been some progress on this issue, but it isn't completely fixed. I think the main restriction is that for large FFTs the FFT length must be factorizable into primes smaller than 127, which is a limitation we inherit from https://docs.nvidia.com/cuda/cufft/index.html#function-cufftmakeplanmany64

Other than that, it's possible things just work now with a jaxlib built from head.
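Given the factorization restriction described above, a length can be checked, and rounded up if necessary, before calling the FFT. This is a minimal sketch that takes "primes smaller than 127" literally as a strict bound. Check the linked cuFFT documentation for the exact condition. `is_smooth` and `next_smooth` are hypothetical helpers. Zero-padding to a longer length changes the transform. That is harmless for FFT convolution, where inputs are padded anyway, but it is not a drop-in replacement for an exact-length FFT.

```python
def is_smooth(n: int, bound: int = 127) -> bool:
    """True if every prime factor of n is strictly smaller than `bound`."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    # Divide out every candidate below `bound`; composite candidates never
    # divide because their prime factors have already been removed.
    for p in range(2, bound):
        while n % p == 0:
            n //= p
    return n == 1


def next_smooth(n: int, bound: int = 127) -> int:
    """Smallest m >= n such that is_smooth(m, bound)."""
    m = n
    while not is_smooth(m, bound):
        m += 1
    return m
```

For example, `next_smooth(127)` returns 128, because 127 is itself prime.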