jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Pattern match dot algorithm spec to preset name #24820

Open dfm opened 3 days ago

dfm commented 3 days ago

Some dot algorithm presets have special-cased input and output storage type behavior (e.g. all the BF16 algorithms), and JAX deals with these cases (for better or worse) at the "preset" level. This PR provides a small quality-of-life improvement: it converts known lax.DotAlgorithm specs to explicit lax.DotAlgorithmPreset members whenever possible. For example, if a user specifies:

precision = lax.DotAlgorithm(dtypes.bfloat16, dtypes.bfloat16, np.float32,
                             num_primitive_operations=6)

this will be canonicalized to lax.DotAlgorithmPreset.BF16_BF16_F32_X6 and the input and output casting will be handled properly.
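
To illustrate the idea, here is a minimal sketch of table-based pattern matching from a spec to a preset. It is not the code in this PR: `AlgorithmSpec`, `Preset`, `_KNOWN_SPECS`, and `canonicalize` are hypothetical stand-ins written in plain Python, with dtypes represented as strings so the example runs without JAX. Only the preset names mirror real `lax.DotAlgorithmPreset` members.

```python
from enum import Enum
from typing import NamedTuple, Union


class AlgorithmSpec(NamedTuple):
    """Hypothetical stand-in for lax.DotAlgorithm (dtypes given as strings)."""
    lhs_precision_type: str
    rhs_precision_type: str
    accumulation_type: str
    num_primitive_operations: int = 1


class Preset(Enum):
    """Hypothetical stand-in for a few lax.DotAlgorithmPreset members."""
    BF16_BF16_F32 = "BF16_BF16_F32"
    BF16_BF16_F32_X3 = "BF16_BF16_F32_X3"
    BF16_BF16_F32_X6 = "BF16_BF16_F32_X6"


# Lookup table from fully specified algorithm specs to preset members.
# NamedTuples are hashable and compare by value, so they work as dict keys.
_KNOWN_SPECS = {
    AlgorithmSpec("bfloat16", "bfloat16", "float32", 1): Preset.BF16_BF16_F32,
    AlgorithmSpec("bfloat16", "bfloat16", "float32", 3): Preset.BF16_BF16_F32_X3,
    AlgorithmSpec("bfloat16", "bfloat16", "float32", 6): Preset.BF16_BF16_F32_X6,
}


def canonicalize(
    precision: Union[AlgorithmSpec, Preset],
) -> Union[AlgorithmSpec, Preset]:
    """Return the matching preset if the spec is known, else the input unchanged."""
    if isinstance(precision, Preset):
        return precision  # Already a preset; nothing to do.
    return _KNOWN_SPECS.get(precision, precision)


spec = AlgorithmSpec("bfloat16", "bfloat16", "float32",
                     num_primitive_operations=6)
print(canonicalize(spec).name)
```

Specs that do not match any table entry pass through untouched, so in this sketch only the known presets pick up preset-level casting behavior.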