This note is very helpful. During debugging I graphed this function and did find that in isolation it removed the factor of N as well, so this is very perplexing. It seems plausible that either:
a) the compiler is not finding this optimization in the full code, b) there is some other O(HNL) memory term that we are missing, or c) this term comes up in backprop (although I think your experiment eliminates that possibility).
I will raise this on the JAX discussions and ask people at HF. One thing I was recommended last time is that you can actually inspect the generated code: https://jax.readthedocs.io/en/latest/jaxpr.html. This is a bit of work, but might be interesting.
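For example, `jax.make_jaxpr` prints the traced program along with the shape of every intermediate value, which makes it easy to see where `(L, N)`-shaped values appear before XLA's fusion decisions. A minimal sketch (the toy kernel and shapes here are illustrative, not the repo's actual code):

```python
import jax
import jax.numpy as jnp

def kernel(omega_scalar, lambd):
    # Per-scalar kernel in the style discussed below, meant to be vmap'd over L.
    return (1.0 / (omega_scalar - lambd)).sum()

omega = jnp.ones(1024, dtype=jnp.complex64)  # (L,)
lambd = jnp.ones(64, dtype=jnp.complex64)    # (N,)

# The jaxpr lists every intermediate value with its shape, so any
# (L, N) = (1024, 64) temporary shows up explicitly in the trace.
print(jax.make_jaxpr(jax.vmap(kernel, in_axes=(0, None)))(omega, lambd))
```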
Copied to https://github.com/albertfgu/annotated-s4/pull/9 so other PRs can be stacked in my fork before merging here.
JAX bug filed at https://github.com/google/jax/issues/11007
I dug into the memory issue here.
The memory issue should arise from the Cauchy computation with signature `(L), (N) -> (L)`, which is essentially this computation:
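(A rough sketch of that computation; the repo's exact code differs in names and details.)

```python
import jax.numpy as jnp

def cauchy_naive(v, omega, lambd):
    # v, lambd: shape (N,); omega: shape (L,); output: shape (L,).
    # Broadcasting materializes an explicit (L, N) intermediate before the sum.
    return (v[None, :] / (omega[:, None] - lambd[None, :])).sum(axis=-1)
```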
This takes O(LN) memory with the above code, and O(LNH) memory when broadcast over the H dimension. However, the current implementation considers only the N dimension and is later vmap'd over the L and H dimensions. Supposedly XLA should be smart enough to optimize away the intermediate memory: the memory usage should ideally be O(LH + NH) instead of O(LNH). Note that this is independent of batch size. Meanwhile, the activations take O(BLH) memory, which is independent of N. So if the memory is optimized properly, the Cauchy kernel should not be noticeable as N scales, because (L + N)H << BLH. For reference, the PyTorch S4 repo uses 2217/2237 MB for N = 64/256 with a model equivalent to the configuration below.
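To put rough numbers on the gap (assuming complex64, i.e. 8 bytes per element, at the (L, N, H) = (1024, 64, 256) setting used below):

```python
L, N, H = 1024, 64, 256
bytes_per_elem = 8  # complex64

lnh = L * N * H * bytes_per_elem      # one materialized (L, N, H) intermediate
lh_nh = (L + N) * H * bytes_per_elem  # the ideal O((L + N) H) footprint

print(f"O(LNH):    {lnh / 2**20:.0f} MiB per copy")   # 128 MiB, and 4x that at N = 256
print(f"O((L+N)H): {lh_nh / 2**20:.2f} MiB")          # ~2 MiB
```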
To confirm the issue, I ran models with different N dimensions and saw that the memory scales with N. Specifically, I ran the training command with (L, N, H) = (1024, 64, 256) and varied `ssm_n` (i.e. N) between 64, 256, and 512. This used 3494/5554/9636 MB of GPU memory for N = 64/256/512, which scales pretty much exactly as one would expect for O(LNH) memory use.
Additionally, I ran the same commands for the `dss` model and found the exact same memory usage. The DSS kernel explicitly materializes a tensor of shape (N, L), so it would use O(HNL) memory assuming naive broadcasting over the H dimension. So everything seems consistent with O(HNL) memory use.

Finally, I ran the S4D model, which is similar but does not materialize the (N, L) tensor:
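(Sketch of the pattern only, not the repo's exact S4D code; the discretization details are glossed over and the names are illustrative.)

```python
import jax
import jax.numpy as jnp

def s4d_kernel(C, A, step, L):
    # C, A: shape (N,) complex parameters of a diagonal SSM; output: shape (L,) real.
    Abar = jnp.exp(A * step)   # (N,) discretized diagonal state matrix
    Bbar = (Abar - 1.0) / A    # (N,) discretized input matrix

    def at_step(l):
        # Each step only touches (N,)-sized tensors; no (N, L) tensor is built.
        return (C * Bbar * Abar**l).sum().real

    # vmap over the L time steps, keeping every intermediate O(N) or O(L).
    return jax.vmap(at_step)(jnp.arange(L))
```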
This computes a map of shape `(N) -> (L)` like DSS but `vmap`s over the L dimension. Running this model uses around 2500 MB of memory independent of the size of N, as desired.

This PR has my investigation into the issue.
First, I moved one of the `vmap` abstractions around in the hopes that it would guide the compiler better. The original code wrote the Cauchy kernel without an L dimension and had several intermediate abstractions, like the generating function, which made it a bit harder for me to follow.

I implemented the Cauchy kernel directly for inputs of shape (N) and (L), in exactly the same style as the S4D kernel above. Then I combined `K_gen_DPLR` and `conv_from_gen`, which are always called together, into a single function `kernel_DPLR` that exposes all the N and L dimensions. All tensors should be of size N or L, and something of shape (N, L) is never explicitly materialized. As far as I can tell, the only place something of size HNL is possible is inside the `cauchy` function. Despite the `cauchy` function being so simple and written in exactly the same way as the efficient S4D kernel above, this model still uses the same amount of memory as before.
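For reference, the rewritten `cauchy` is essentially this pattern (a sketch of the style, not a verbatim copy of the PR's code):

```python
import jax
import jax.numpy as jnp

def cauchy(v, omega, lambd):
    # v, lambd: shape (N,); omega: shape (L,); output: shape (L,).
    def at_omega(o):
        # Per frequency point, only (N,)-sized intermediates are needed.
        return (v / (o - lambd)).sum()

    # vmap over the L frequencies, mirroring the S4D kernel style above.
    return jax.vmap(at_omega)(omega)
```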
Next I tried to call these functions directly in `s4.py` instead of in the main train script `train.py`. I basically just copied all the model and optimizer logic as closely as possible, and ran `XLA_PYTHON_CLIENT_PREALLOCATE=false python -m s4.s4` to check the memory use. This was always memory efficient, independent of N! I got stuck here and have no idea what else could be going on. It's also possible I made a mistake in reproducing the train script.
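For anyone who wants to poke at this outside the train script, a self-contained check in the same spirit (illustrative only; this is not the actual contents of `s4.s4`, and the kernel, shapes, and loss here are stand-ins) can be run under `XLA_PYTHON_CLIENT_PREALLOCATE=false` while watching GPU memory in `nvidia-smi`:

```python
import jax
import jax.numpy as jnp

N, L, H = 512, 1024, 256

def cauchy(v, omega, lambd):
    # Same vmap-over-L style as above: per-point work is O(N).
    return jax.vmap(lambda o: (v / (o - lambd)).sum())(omega)

def loss(params):
    v, lambd = params                                 # each of shape (H, N)
    omega = jnp.exp(2j * jnp.pi * jnp.arange(L) / L)  # (L,) evaluation points
    # Broadcast the kernel over the H features, as the full model does.
    k = jax.vmap(lambda vh, lh: cauchy(vh, omega, lh))(v, lambd)  # (H, L)
    return jnp.sum(jnp.abs(k))

params = (
    jnp.ones((H, N), dtype=jnp.complex64),
    (-0.5 + 1j * jnp.arange(1, N + 1)) * jnp.ones((H, N), dtype=jnp.complex64),
)
# Forward + backward under jit, so saved-for-backward intermediates count too.
grads = jax.jit(jax.grad(loss))(params)
grads[0].block_until_ready()
```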