albertfgu opened this issue 2 years ago
These issues (described below) go away with different versions of the kernel. For example, the problem disappears when replacing

```python
cauchy_dot = lambda _omega: (1. / (_omega - lambd)).sum()
```

with

```python
cauchy_dot = lambda _omega: (_omega - lambd).sum()
```

or even

```python
cauchy_dot = lambda _omega: (lambd / np.exp(_omega * lambd)).sum()
```

I wondered if it was related specifically to taking powers (there is a power of -2 in the backward pass of the original version), e.g. https://github.com/deepmind/optax/issues/196#issuecomment-975666744
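To make that hypothesis concrete, here is a minimal standalone check of where the -2 power comes from (not part of the MWE itself):

```python
import jax

# d/dw [1/(w - c)] = -(w - c)**-2, i.e. the backward pass of the
# original kernel contains a power of -2.
f = lambda w: 1. / (w - 2.0)
print(jax.grad(f)(5.0))  # -1/(5 - 2)**2, approximately -0.1111
```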
I have a model where the core computational kernel is a broadcast-reduce pattern that implicitly requires a large tensor; XLA usually optimizes it away so the tensor is never materialized. However, an innocuous change to the parameter initialization in Flax causes the memory to blow up.

This is the core computation:
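A minimal sketch, built around the `cauchy_dot` lambda shown above; the weight argument `v` and the exact signature of `cauchy` are assumptions inferred from the shapes described below:

```python
import jax
import jax.numpy as np

@jax.jit  # see the note below about what happens when this decorator is removed
def cauchy(v, omega, lambd):
    # For a single scalar _omega, reduce over the N entries of lambd.
    # v stands in for the (N,)-shaped parameter the module passes in
    # (e.g. self.x * self.y below).
    cauchy_dot = lambda _omega: (v / (_omega - lambd)).sum()
    # vmap over the L entries of omega: output shape (L,).
    return jax.vmap(cauchy_dot)(omega)
```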
I'll use `N` to denote the length of `lambd`, and `L` to denote the length of `omega`. Note that this function implicitly produces an intermediate tensor (`1/(omega-lambd)`) of shape `(L, N)`, but this theoretically does not need to be materialized.

We wrap it very thinly in a Flax Module:
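A hedged sketch of the wrapper, continuing from the kernel above; the initializer choice and the fixed `lambd` values are assumptions, and the "four commented lines" mentioned below are not reproduced:

```python
import flax.linen as nn
import jax.numpy as np

class CauchyModule(nn.Module):
    N: int

    def setup(self):
        # Parameter definitions whose "equivalent" rewrites trigger the
        # blowup described below (exact initializers are an assumption).
        self.x = self.param("x", nn.initializers.normal(), (self.N,))
        self.y = self.param("y", nn.initializers.normal(), (self.N,))

    def __call__(self, omega):
        # Fixed lambd of length N; values are arbitrary, chosen only to
        # avoid dividing by zero in this sketch.
        lambd = 2.0 + np.arange(self.N) / self.N
        # self.x * self.y is the (N,)-shaped parameter fed to the kernel.
        return cauchy(self.x * self.y, omega, lambd)
```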
We then `vmap` it to include an extra "channel" dimension called `C`. After this vmap, the call to `cauchy` takes a parameter of shape `(N, C)` to an output of shape `(L, C)` while implicitly defining a tensor of shape `(L, N, C)`. (This intermediate tensor can be confirmed by inspecting the jaxpr.) However, this tensor does not need to be materialized, which is reflected in the memory usage reported below.

The following completes the minimal working example by feeding dummy inputs to this Module:
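Continuing the sketch, a hedged reconstruction of the rest of the MWE: the `nn.vmap` axes, loss, and optimizer are assumptions, while the shapes and the jitted `train_step` calling `grad_fn` and `state.apply_gradients` follow the description in this issue:

```python
import jax
import jax.numpy as np
import optax
from flax.training import train_state

B, L, N, C = 32, 1024, 1024, 256

# Channel vmap: each of the C channels gets its own (N,) parameters, so
# the lifted parameter has shape (N, C) and the output shape (L, C).
ChannelModule = nn.vmap(
    CauchyModule,
    in_axes=1, out_axes=1,
    variable_axes={"params": 1},
    split_rngs={"params": True},
)

# Batch vmap: parameters are shared across the batch dimension B.
BatchModule = nn.vmap(
    ChannelModule,
    in_axes=0, out_axes=0,
    variable_axes={"params": None},
    split_rngs={"params": False},
)

model = BatchModule(N=N)
dummy = np.ones((B, L, C))  # dummy omega inputs of shape (B, L, C)
params = model.init(jax.random.PRNGKey(0), dummy)["params"]

state = train_state.TrainState.create(
    apply_fn=model.apply, params=params, tx=optax.adam(1e-3)
)

@jax.jit
def train_step(state, batch):
    def loss_fn(p):
        out = state.apply_fn({"params": p}, batch)
        return np.mean(out ** 2)  # arbitrary scalar loss for the sketch
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

for _ in range(10):
    state, loss = train_step(state, dummy)
```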
**Correct Behavior:**

I ran the script with

```
XLA_PYTHON_CLIENT_PREALLOCATE=false python -m s4.test2 -B 32 -C 256 -L 1024 -N 1024
```

on an A100 GPU. As expected, the GPU memory is constant across values of N (from 64 to 1024), around 1.3-1.4GB for me. This indicates that the intermediate tensor of shape `(L, N, C)` is not being materialized.
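(For reference, the intermediate can be confirmed with `jax.make_jaxpr`, using the `model`/`params`/`dummy` names from the sketch above:)

```python
# The (L, N, C) intermediate appears in the traced jaxpr even though XLA
# is free to fuse it away when compiling.
print(jax.make_jaxpr(lambda p, x: model.apply({"params": p}, x))(params, dummy))
```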
**Strange memory issues:**

However, several small changes make the memory blow up for the above command.

- Removing the `jax.jit` decorator on `cauchy` causes memory usage to go up to 5.5GB without changing the speed. Note that the entire `train_step` is wrapped in a jit call, so `cauchy` should still be jit'd anyways. I'm not sure why the speed is correct but the memory is high when the decorator is removed (this might not be a bug, just something I don't understand about JIT).
- The thing that I think is a bug is the following: replacing the `self.x` and `self.y` parameter definitions with the four commented lines causes the memory usage to spike to 5.7GB (and grow with N). These initializations should clearly be equivalent...

So somehow, this tiny change seems to break the XLA compiler and cause it to materialize this `(L, N, C)` tensor again. Furthermore, starting from this change that has the memory blowup, the following changes all lower the memory back to 1.3GB:

- passing `1*self.y` instead of `self.x*self.y` into the `cauchy` call
- removing `state = state.apply_gradients(grads=grads)` (but still calling `grad_fn`!)
- removing the `nn.vmap` that adds a batch dimension. Note that the input tensor's memory (shape `(B, L, C)`) is dominated by the intermediate `(L, N, C)` tensor.

None of these really make sense to me, since they don't affect the main potential memory bottleneck. I'm a relative JAX novice, so maybe some of these things are expected, but I'm having trouble figuring it out. My understanding is that some precise combination of seemingly irrelevant changes is causing the XLA compiler to be unable to optimize this code.