Open jasonli0707 opened 2 years ago
Thanks for the report; your code is correct!
This actually looks like two bugs on our side:
1) the `store_on_device` argument isn't working, and the kernel is stored on the GPU (I'm assuming you have enough CPU RAM, so you're not running out of it);
2) even if `store_on_device=True`, 24 GB of GPU RAM should be enough for the 50k x 50k kernel, but somehow it isn't. I suspect JAX and TensorFlow may be competing for GPU memory; could you try running this version of the code on your machine?
https://colab.research.google.com/gist/romanngg/96421af87f4cc1e13a78454d8bfb4ee9/memory_repro.ipynb
The part that hopefully helps is:

```python
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')  # hide the GPU from TensorFlow
import tensorflow_datasets as tfds
```
(and I'm using `tfds` instead of `neural_tangents.examples`)
Another idea is to binary-search over smaller training set sizes to figure out whether we're really hitting the memory limit (e.g. it works for 40k but not 50k), or whether the GPU memory is simply not available for some reason (e.g. it fails even at 5k).
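That binary search over the training set size can be sketched as follows; `runs_ok` here is a hypothetical stand-in for attempting the kernel computation on the first `n` examples and reporting whether it completed without running out of memory:

```python
def max_feasible_size(runs_ok, lo, hi):
    """Largest n in [lo, hi] with runs_ok(n) True (assumes monotonicity)."""
    best = lo
    while lo <= hi:
        mid = (lo + hi) // 2
        if runs_ok(mid):
            best = mid       # mid fits; try larger sizes
            lo = mid + 1
        else:
            hi = mid - 1     # mid OOMs; try smaller sizes
    return best

# Example with a simulated memory limit at 36000 samples:
print(max_feasible_size(lambda n: n <= 36000, 1000, 50000))  # → 36000
```

Each probe is one full (possibly failing) run, so this takes O(log) attempts rather than trying every size.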
Also, could you please post the whole error message?
Thank you so much for the detailed reply!
I have tried your code but still face the same issue. Below shows the complete error message for your reference:
```
2022-09-08 13:20:36.044808: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 9.31GiB (rounded to 10000000000) requested by op
2022-09-08 13:20:36.044942: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:491] ***
2022-09-08 13:20:36.045005: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2130] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 10000000000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
    parameter allocation:         9.31GiB
    constant allocation:              0B
    maybe_live_out allocation:    9.31GiB
    preallocated temp allocation:     0B
    total allocation:            18.63GiB
    total fragmentation:              0B (0.00%)
  Peak buffers:
    Buffer 1:
        Size: 9.31GiB
        Entry Parameter Subshape: s32[50000,50000]

    Buffer 2:
        Size: 9.31GiB
        Operator: op_name="jit(add)/jit(main)/add" source_file="/home/jason/dev/neural-tangents/neural_tangents/_src/predict.py" source_line=1222
        XLA Label: fusion
        Shape: s32[50000,50000]

    Buffer 3:
        Size: 4B
        Entry Parameter Subshape: s32[]

Traceback (most recent call last):
  File "mnist.py", line 68, in
```
I have also tried searching for the maximum number of samples before encountering the memory issue, which turned out to be 36000 in my case:
```python
num_samples = 36000
x_train = x_train[:num_samples]
y_train = y_train[:num_samples]
```
Oh, thanks for the error message! I realized what's actually failing is

```python
fx_train_inf, fx_test_inf = predict_fn(fx_train_0=fx_train_0, fx_test_0=fx_test_0, k_test_train=k_test_train)
```
and not the kernel computation. Indeed, 24 GB is not enough to run the Cholesky solve on a 50k x 50k matrix, so you'd need to do it on the CPU.
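A quick back-of-envelope check makes the numbers above concrete (assuming 4-byte entries, as in the `s32[50000,50000]` buffers in the error message):

```python
# One 50k x 50k matrix of 4-byte entries, in GiB.
n = 50_000
matrix_gib = n * n * 4 / 2**30
print(f"one {n} x {n} matrix: {matrix_gib:.2f} GiB")   # 9.31 GiB
print(f"two live buffers:    {2 * matrix_gib:.2f} GiB")  # 18.63 GiB

# A Cholesky solve keeps at least the input and the factor alive at once
# (plus XLA temporaries), so a 24 GiB GPU is quickly exhausted.
```

This matches the log exactly: a 9.31 GiB parameter plus a 9.31 GiB fusion output gives the reported 18.63 GiB total allocation, before any temporaries.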
To make it happen on CPU, I think the easiest way is

```python
predict_fn = jit(predict_fn, backend='cpu')
```

right after you define it (and it's good to jit this function anyway).
Alternatively (though hopefully this won't be necessary), you can pin the input tensors to the CPU, to make sure the function called with them as inputs is executed on the CPU:
```python
fx_train_0 = jax.device_put(fx_train_0, jax.devices('cpu')[0])
fx_test_0 = jax.device_put(fx_test_0, jax.devices('cpu')[0])
k_test_train = jax.device_put(k_test_train, jax.devices('cpu')[0])
```
and/or
```python
k_train_train = jax.device_put(k_train_train, jax.devices('cpu')[0])
y_train = jax.device_put(y_train, jax.devices('cpu')[0])
```
before defining `predict_fn`. In general, you can print `x.device_buffer.device()` in various places to see which device each tensor `x` is stored on, to figure out what is happening on CPU vs. GPU (you want your last line to be executed on the CPU).
Thank you so much for the detailed follow-up!
As you suggested, I moved everything to the CPU before defining `predict_fn` and verified that the tensors were indeed stored on the CPU. However, after a few minutes the program is killed by SIGSEGV (address boundary error). Does this mean I'm also out of CPU RAM? If so, is there anything I can do?
How much RAM do you have? Does it work (on CPU, after your modifications) with 36k points? I suspect you'd need at least ~64 GB of RAM, but I've only ever tried it on a machine with >128 GB, so I'm not sure what the exact requirement is.
To debug this further, you could run the piece of code from https://github.com/google/neural-tangents/issues/152#issuecomment-1121615513 first with numpy/scipy and then with jax.numpy and jax.scipy, to get a smaller repro. Then you could post it to https://github.com/google/jax and ask what they think. I also occasionally get these low-level errors when doing level-3 linear algebra on large matrices, and don't know how to debug them myself (e.g. https://github.com/google/jax/issues/10411, https://github.com/google/jax/issues/10420).
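For reference, a minimal NumPy/SciPy version of that kind of repro might look like the sketch below (at a deliberately small size; scaling `n` toward 50,000 reproduces the memory pressure, at roughly 4·n² bytes per float32 matrix):

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve

n = 2_000  # increase toward 50_000 to stress memory
rng = np.random.default_rng(0)
a = rng.standard_normal((n, n)).astype(np.float32)
k = a @ a.T + n * np.eye(n, dtype=np.float32)  # SPD, kernel-like matrix
y = np.ones((n, 1), dtype=np.float32)

c, low = cho_factor(k)      # Cholesky factorization
x = cho_solve((c, low), y)  # triangular solves
print(np.allclose(k @ x, y, atol=1e-2))
```

If the NumPy/SciPy version runs but the `jax.scipy.linalg.cho_factor`/`cho_solve` equivalent segfaults at the same size, that isolates the problem to the JAX/XLA side and makes a much cleaner bug report.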
I am working on a simple MNIST example and found that I could not compute the NTK for the entire dataset without running out of memory. Below is the code snippet I used:
I am running this on two RTX 3090s, each with 24 GB of memory. Is there something I'm doing wrong, or is it normal for the NTK to consume so much memory? Thank you!