lsy323 opened 1 week ago
https://github.com/jax-ml/jax/pull/24608 should fix this by calculating the cost estimate directly instead of by compiling a reference function.
In the meantime you could also just remove the cost estimate temporarily, as it won't affect the correctness of the code.
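For reference, the direct calculation amounts to counting the matmul FLOPs, softmax transcendentals, and HBM traffic from the input shapes. A minimal sketch of that idea is below; it assumes Pallas's `pl.CostEstimate` with `flops` / `transcendentals` / `bytes_accessed` fields, and the actual change in the PR may wire this up differently.

```python
# Sketch: compute the attention cost estimate analytically instead of
# compiling a reference attention. `pl.CostEstimate` is assumed to take
# flops / transcendentals / bytes_accessed; the real fix may differ.
from jax.experimental import pallas as pl

def attention_cost_estimate(batch, heads, seq_len, head_dim, bytes_per_elem=2):
    # Two (seq_len x seq_len) matmuls per head: Q @ K^T and P @ V.
    flops = 2 * 2 * batch * heads * seq_len * seq_len * head_dim
    # One exp() per attention logit in the softmax.
    transcendentals = batch * heads * seq_len * seq_len
    # Q, K, V reads plus the output write; the logits never touch HBM.
    bytes_accessed = 4 * batch * heads * seq_len * head_dim * bytes_per_elem
    return pl.CostEstimate(
        flops=flops,
        transcendentals=transcendentals,
        bytes_accessed=bytes_accessed,
    )
```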
Description
With the 0.4.35 release, the flash attention kernel hits a compilation OOM for long-sequence-length inputs. Compilation fails while building the reference attention implementation used for cost analysis, added in this commit: https://github.com/jax-ml/jax/commit/4c218fbf3b8431a5f75cdf20942d5d62433a8657. This logic causes an OOM when the device HBM is not large enough for the naive attention but is large enough for the flash attention kernel.
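As a rough illustration (the shapes here are hypothetical, not from the original report): with 8 heads and a 32k sequence in bf16, the naive attention logits alone take about 8 × 32768² × 2 bytes ≈ 17 GB, already about the size of the entire 16 GB of v5e HBM, whereas the flash kernel only ever materializes block-sized tiles of that matrix.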
Here is a small repro script for this issue (on a v5e TPU with 16GB HBM).
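A minimal sketch along these lines (illustrative, not the original script), assuming the Pallas TPU kernel under `jax.experimental.pallas.ops.tpu.flash_attention` and shapes chosen so the naive attention matrix no longer fits in HBM:

```python
# Sketch of a repro: long-sequence flash attention on a 16GB v5e.
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention

batch, heads, seq_len, head_dim = 1, 8, 32768, 128
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (batch, heads, seq_len, head_dim), dtype=jnp.bfloat16)
k = jax.random.normal(kk, (batch, heads, seq_len, head_dim), dtype=jnp.bfloat16)
v = jax.random.normal(kv, (batch, heads, seq_len, head_dim), dtype=jnp.bfloat16)

# Compiling the kernel triggers the cost-analysis path, which compiles the
# naive reference attention and OOMs at this sequence length on 16GB HBM.
out = jax.jit(flash_attention)(q, k, v)
out.block_until_ready()
```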
The error message is:
It failed with 0.4.35 but passed with the 20240829 nightly.
cc @WoosukKwon who surfaced and root-caused the issue in the first place.
System info (python version, jaxlib version, accelerator, etc.)
Version with the compilation OOM issue (jax 0.4.35):
Version without the compilation OOM issue (jax nightly 20240829):
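If useful, the environment details for each install can be regenerated with JAX's built-in helper:

```python
# Prints Python, jax/jaxlib versions, and the detected accelerator,
# i.e. the details requested under the "System info" heading above.
import jax
jax.print_environment_info()
```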