sbodenstein opened this issue 1 month ago
Thanks for bringing this up. It's good to learn about these new `algorithm` and `transpose_algorithm` options. As for `jnp.einsum` support, can we simply use its `_dot_general` arg and pass in a wrapped `lax.dot_general` with `algorithm` and `transpose_algorithm` set? Something like:

```python
from functools import partial

dot_general = partial(lax.dot_general, algorithm="BF16_BF16_F32",
                      transpose_algorithm=("BF16_BF16_BF16", "BF16_BF16_BF16"))
logits = jnp.einsum('BTNH,BSNH->BNTS', query, key, _dot_general=dot_general)
```
Hmm, it seems the `algorithm` and `transpose_algorithm` arguments are no longer there in `lax.dot`. @sbodenstein, do you know if they were reverted?
It seems that `DotAlgorithm` is now an option for `precision`, so I'll try that instead.
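For reference, a minimal sketch of that route, assuming a recent JAX where `precision` accepts a `DotAlgorithmPreset` (the shapes here are made up, and whether a given preset is supported depends on the backend/XLA version):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((128, 64), dtype=jnp.bfloat16)
y = jnp.ones((64, 32), dtype=jnp.bfloat16)

# The dot algorithm is passed through `precision` instead of a separate
# `algorithm` argument; support varies by backend.
out = lax.dot_general(
    x, y,
    dimension_numbers=(((1,), (0,)), ((), ())),  # plain matmul, no batch dims
    precision=lax.DotAlgorithmPreset.BF16_BF16_F32,
)
```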
By the way, when I experimented with `DotAlgorithm` last week, I encountered errors like this:

```
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Failed to serialize StableHLO;
Detailed error from MLIR: /home/tmp/jax_sdpa_bwd_precision/dot_demo.py:10:0: error: failed to legalize operation
'vhlo.dot_general_v2' that was explicitly marked illegal
```

I suspect this might be because the feature only works when the PJRT C API is not used, but our container has it enabled by default. I have limited knowledge of the PJRT C API, so I'd like to understand the situation better.
We changed the design, which just landed (https://github.com/jax-ml/jax/pull/24079). This should now work with einsum! I'm not sure about the issues you are seeing.
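To illustrate, a hedged sketch of what the post-#24079 usage could look like directly through `jnp.einsum` (the attention-shaped arrays and the choice of preset are illustrative, not taken from this thread):

```python
import jax.numpy as jnp
from jax import lax

B, T, S, N, H = 2, 8, 8, 4, 16
query = jnp.ones((B, T, N, H), dtype=jnp.bfloat16)
key = jnp.ones((B, S, N, H), dtype=jnp.bfloat16)

# The algorithm is forwarded by einsum to lax.dot_general via `precision`,
# so no `_dot_general` wrapping is needed.
logits = jnp.einsum(
    'BTNH,BSNH->BNTS', query, key,
    precision=lax.DotAlgorithmPreset.BF16_BF16_F32,
    preferred_element_type=jnp.float32,
)
```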
Thx for the pointer. I will take a look.
Basically, I saw that all the tests are protected by this check, so I want to know more about it. I can also ask in the gchat.
@kaixih — Sorry I didn't see this conversation yesterday!
Yeah, unfortunately this feature (the new `DotAlgorithm`) doesn't currently work with the PJRT plugin 😢. The issue is that, for compatibility reasons, the plugin serializes the HLO to an old dialect that doesn't support the algorithm argument to the `dot_general` operation. I think @matthiaskramm is working on adding support for targeting more recent HLO dialects via the plugin, but I'm not sure what the status is.
Also looping in @skye and @GleasonK who have been thinking about all this. No action needed; just an example of one potentially high impact use of the versioning work!
I think we've figured out another workaround to fix this; the PR has just been created: https://github.com/jax-ml/jax/pull/24352.
@sbodenstein, can you take a look? ^^
Description
`jax.nn.dot_product_attention` does the first dot product with `preferred_element_type=jnp.float32` (see here). For BF16 inputs, this prevents an unnecessary downcast to BF16 (it can improve numerical stability and adds no extra memory usage for FlashAttention) and matches cuDNN numerics. However, in the backward pass, XLA now sees a matmul with FP32 inputs. This leads to a number of inconsistencies:

- The backward-pass matmuls are done in FP32, and can be affected by settings like `jax.default_matmul_precision`, which would cause further confusion as these have no impact on the cuDNN implementation. This has half the flops a user might expect, so it is fundamentally less performant than expected.
- TPU will also do matmuls in BF16, consistent with cuDNN.
`jax.nn.dot_product_attention` with BF16/FP16 inputs should have device- and implementation-independent numerics. It should choose the TPU + cuDNN convention. This should be implemented via the new `algorithm` and `transpose_algorithm` args (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.dot.html#jax.lax.dot) once they are available in `jnp.einsum`.

To explicitly see the FP32 matmuls in the backward pass:
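A minimal sketch of one way to do this (illustrative shapes, not the original snippet): lower the gradient of the first attention matmul and inspect the HLO, where the backward-pass `dot_general` ops involve f32 operands.

```python
import jax
import jax.numpy as jnp

def logits_sum(query, key):
    # Same contraction as the first matmul in jax.nn.dot_product_attention,
    # with FP32 accumulation requested via preferred_element_type.
    logits = jnp.einsum('BTNH,BSNH->BNTS', query, key,
                        preferred_element_type=jnp.float32)
    return jnp.sum(logits)

q = jnp.ones((2, 8, 4, 16), dtype=jnp.bfloat16)  # (B, T, N, H)
k = jnp.ones((2, 8, 4, 16), dtype=jnp.bfloat16)  # (B, S, N, H)

# Print the lowered HLO of the backward pass; the gradient dot_general ops
# take the f32 cotangent as an operand rather than pure BF16 inputs.
print(jax.jit(jax.grad(logits_sum, argnums=(0, 1))).lower(q, k).as_text())
```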
@dfm, @kaixih, @phaw
System info (python version, jaxlib version, accelerator, etc.)
Not applicable.