jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

dot_product_attention has inconsistent dot types in backward pass #24047

Open sbodenstein opened 1 month ago

sbodenstein commented 1 month ago

Description

jax.nn.dot_product_attention does the first dot product with preferred_element_type=jnp.float32 (see here). For BF16 inputs, this avoids an unnecessary downcast to BF16 (which can improve numerical stability and adds no extra memory usage for FlashAttention) and matches cuDNN numerics. However, in the backward pass, XLA now sees a matmul with FP32 inputs. This leads to a number of inconsistencies.

To explicitly see the FP32 matmuls in the backward pass:

import jax
import jax.numpy as jnp

def loss(x):
  return jnp.sum(jax.nn.dot_product_attention(x, x, x, implementation='xla'))

x = jax.random.normal(jax.random.PRNGKey(0), (8, 128, 8, 64), dtype=jnp.bfloat16)
f_grad = jax.jit(jax.grad(loss))

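# The compiled HLO below contains backward-pass dots with f32 operands.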
print(f_grad.lower(x).compile().as_text())
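
For reference, here is a minimal standalone sketch of the same pattern outside the attention API (the shapes and einsum spec are arbitrary placeholders): a BF16 dot accumulated in FP32 via preferred_element_type, whose backward pass then contains dots with FP32 operands.

import jax
import jax.numpy as jnp

def f(q, k):
  # BF16 inputs, FP32 accumulation/output, mirroring the first dot inside
  # dot_product_attention (preferred_element_type=jnp.float32).
  logits = jnp.einsum('td,sd->ts', q, k, preferred_element_type=jnp.float32)
  return jnp.sum(logits)

q = jnp.ones((16, 32), dtype=jnp.bfloat16)
k = jnp.ones((16, 32), dtype=jnp.bfloat16)

# The transposed dots in the backward pass see an FP32 cotangent operand.
print(jax.jit(jax.grad(f)).lower(q, k).compile().as_text())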

@dfm, @kaixih, @phaw

System info (python version, jaxlib version, accelerator, etc.)

Not applicable.

kaixih commented 1 month ago

Thanks for bringing this up. It is good to learn about these new algorithm and transpose_algorithm options. As for jnp.einsum support, can we simply use its _dot_general argument and pass in a wrapped lax.dot_general with algorithm and transpose_algorithm set? Something like:

dot_general = partial(lax.dot_general, algorithm="BF16_BF16_F32",
                      transpose_algorithm=("BF16_BF16_BF16", "BF16_BF16_BF16"))
logits = jnp.einsum('BTNH,BSNH->BNTS', query, key, _dot_general=dot_general)

kaixih commented 1 month ago

Hmm, it seems the algorithm and transpose_algorithm arguments are no longer there on lax.dot_general. @sbodenstein Do you know if they were reverted?

kaixih commented 1 month ago

It seems that DotAlgorithm is now an option for precision, so I’ll try that instead.
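
Roughly what I have in mind (a sketch only; I'm assuming the DotAlgorithmPreset presets exposed under jax.lax and haven't verified this against the new API yet):

import jax.numpy as jnp
from jax import lax

x = jnp.ones((128, 64), dtype=jnp.bfloat16)
y = jnp.ones((64, 256), dtype=jnp.bfloat16)

# Contract x's dim 1 with y's dim 0; request BF16 inputs with FP32 accumulation
# via the precision argument.
out = lax.dot_general(
    x, y,
    dimension_numbers=(((1,), (0,)), ((), ())),
    precision=lax.DotAlgorithmPreset.BF16_BF16_F32,
)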

By the way, when I experimented with DotAlgorithm last week, I encountered errors like this:

jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Failed to serialize StableHLO;
Detailed error from MLIR: /home/tmp/jax_sdpa_bwd_precision/dot_demo.py:10:0: error: failed to legalize operation
 'vhlo.dot_general_v2' that was explicitly marked illegal

I suspect this might be due to the feature only working when the PJRT C API is not used, but our container has it enabled by default. I have limited knowledge of the PJRT C API, so I'm wondering whether that is indeed the cause.

sbodenstein commented 1 month ago

We changed the design, which just landed (https://github.com/jax-ml/jax/pull/24079). This should now work with einsum! I'm not sure about the issues you are seeing.
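
For example, something along these lines should now be possible (a rough sketch, not run end to end; the arrays are placeholders and I'm assuming the BF16_BF16_F32 preset from jax.lax):

import jax.numpy as jnp
from jax import lax

query = jnp.ones((8, 128, 8, 64), dtype=jnp.bfloat16)  # BTNH
key = jnp.ones((8, 128, 8, 64), dtype=jnp.bfloat16)    # BSNH

# The dot algorithm is passed through einsum's precision argument.
logits = jnp.einsum('BTNH,BSNH->BNTS', query, key,
                    precision=lax.DotAlgorithmPreset.BF16_BF16_F32)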

kaixih commented 1 month ago

Thanks for the pointer. I will take a look.

Basically, I saw that all the tests are protected by this check, so I want to know more about that. I can also ask in the gchat.

dfm commented 1 month ago

@kaixih — Sorry I didn't see this conversation yesterday!

Yeah, unfortunately this feature (the new DotAlgorithm) doesn't currently work with the PJRT plugin 😢. The issue is that, for compatibility reasons, the plugin serializes the HLO to an old dialect that doesn't support the algorithm argument to the dot_general operation. I think that @matthiaskramm is working on adding support for targeting more recent HLO dialects via the plugin, but I'm not sure what the status is.

Also looping in @skye and @GleasonK, who have been thinking about all this. No action needed; this is just an example of one potentially high-impact use of the versioning work!

kaixih commented 1 month ago

I think we've figured out another workaround to fix this. The PR has just been created: https://github.com/jax-ml/jax/pull/24352.

@sbodenstein, can you take a look? ^^