Open ksmdnl opened 1 month ago
tl;dr - You should jit the init function to get a version of it that (1) uses as little memory as possible and (2) runs quickly.
To add a bit more detail, jax.jit (through XLA) applies a number of optimizations to your program. Some of these might reduce the overall peak memory footprint required by the init program. For example, one optimization that XLA does is limiting the live range of arrays.
Let's consider the following JAX program:
def fn():
    a = some_big_array()
    b = other_big_array()
    c = a + b
    d = yet_another_big_array()
    e = c + d
    return e
XLA would be able to notice that a and b can be safely deleted (and as such their GPU memory would be freed) before you compute d:
def fn():
    a = some_big_array()
    b = other_big_array()
    c = a + b
    # XLA knows a/b aren't used again so it can release GPU memory for them.
    free_gpu_memory(a)
    free_gpu_memory(b)
    d = yet_another_big_array()
    e = c + d
    return e
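To make this concrete, here is a minimal runnable sketch of the same program. The array size `N` and the `jnp.ones` / `jnp.full` calls are assumptions standing in for the placeholder `some_big_array()` helpers above. Wrapping `fn` in `jax.jit` hands the whole function to XLA, which can then decide on its own when each intermediate buffer is no longer needed; there is no explicit `free_gpu_memory` call to write.

```python
import jax
import jax.numpy as jnp

N = 1024  # assumed size, kept small so the sketch runs on any device


def fn():
    a = jnp.ones((N, N))        # stands in for some_big_array()
    b = jnp.full((N, N), 2.0)   # stands in for other_big_array()
    c = a + b
    d = jnp.full((N, N), 3.0)   # stands in for yet_another_big_array()
    e = c + d
    return e


# Compile the function as a single XLA program.
jitted_fn = jax.jit(fn)
result = jitted_fn()
```

Only `e` is returned to Python, so the other arrays never need to exist as separate JAX arrays on the Python side.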
If you wanted to debug further and figure out which arrays were still hanging around causing the 50GB allocation to fail, then JAX allows you to see a traceback for where arrays were created, which might help you understand this in more depth. To help you get started, something like the following might work:
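import traceback

import jax

def print_live_arrays():
    # Print shape, dtype and creation traceback for every array still alive.
    for array in jax.live_arrays():
        print(array.shape, array.dtype)
        traceback.print_tb(array.traceback.as_python_traceback())
        print()

try:
    benchmark()
except RuntimeError as e:
    # The exception itself is not a string, so check its message.
    if 'RESOURCE_EXHAUSTED' in str(e):
        print_live_arrays()
    raise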
That said, even if you knew the root cause (e.g. which arrays were hanging around) the recommended fix is the same: use jax.jit and let XLA optimize this for you.
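As a rough illustration of what jitting the init function could look like, here is a sketch. The `init` function below is hypothetical: the real one from the question is not shown, so the argument names, the shapes, and the einsum pattern are all assumptions. The shape arguments are marked static because they determine array shapes and must be known at compile time.

```python
import jax
import jax.numpy as jnp


def init(key, nb_node, dim):
    # Hypothetical init: builds an (nb_node, nb_node) matrix with an einsum.
    k1, k2 = jax.random.split(key)
    q = jax.random.normal(k1, (nb_node, dim))
    k = jax.random.normal(k2, (nb_node, dim))
    return jnp.einsum('id,jd->ij', q, k)


# nb_node and dim control shapes, so they are passed as static arguments.
jitted_init = jax.jit(init, static_argnames=('nb_node', 'dim'))

attn = jitted_init(jax.random.PRNGKey(0), nb_node=8, dim=4)
```

Each new value of `nb_node` triggers a recompilation, so when timing runs it may be worth calling the function once per size before measuring.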
I'm currently doing a runtime analysis of the attention matrix of a transformer. Specifically, I'd like to know how the time complexity behaves w.r.t. the size of the attention matrix.
For nb_node = 414 I'm getting an OOM error in the initialization (when performing the Einstein summation), which looks as follows:
This is quite strange, since I'm using an A100 80GB. However, when one jits the init function, there is no OOM error even at nb_node = 500. My question is: would this be a correct workaround in this case?