jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Add float8_e8m0fnu type support #25116

Open wenscarl opened 5 days ago

wenscarl commented 5 days ago

This PR adds support for the E8M0fnu type. E8M0fnu is an OpenCompute MX scale format with the following properties:

- Unsigned format
- 8 exponent bits
- Exponent range from -127 to 127
- No zero or infinity
- Single NaN value (0xFF)
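To make the properties above concrete, here is a minimal pure-Python sketch of how a single E8M0fnu byte could be decoded. It is not code from this PR: the helper name `decode_e8m0` and the constants are hypothetical, and the bias of 127 is an assumption inferred from the stated exponent range of -127 to 127.

```python
import math

E8M0_BIAS = 127   # assumed bias: byte 0 maps to 2**-127, byte 254 to 2**127
E8M0_NAN = 0xFF   # the single NaN encoding listed above


def decode_e8m0(byte: int) -> float:
    """Decode one E8M0fnu byte into a Python float (illustrative helper)."""
    if not 0 <= byte <= 0xFF:
        raise ValueError(f"expected an 8-bit value, got {byte}")
    if byte == E8M0_NAN:
        return math.nan
    # No sign bit and no mantissa bits: every other encoding is a power of two.
    return math.ldexp(1.0, byte - E8M0_BIAS)
```

Under these assumptions every non-NaN byte decodes to a positive power of two, which is why the format can represent neither zero nor infinity.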

@jakevdp

Smoke test

import jax
import jax.numpy as jnp

def foo(a):
    # Reinterpret the 8-bit payload as float8_e8m0fnu without changing any bits.
    return jax.lax.bitcast_convert_type(a, new_dtype=jnp.float8_e8m0fnu)

a = jnp.ones((2, 2), dtype=jnp.float8_e4m3fn)
foo_jit = jax.jit(foo)

# StableHLO
print(foo_jit.lower(a).as_text("stablehlo"))
# HLO
print(foo_jit.lower(a).as_text("hlo"))
# Compiled (optimized) HLO
print(foo_jit.lower(a).compile().as_text())

Running the smoke test fails with the following error:


jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Failed to serialize StableHLO;

Detailed error from MLIR: <unknown>:0: error: failed to legalize operation 'vhlo.func_v1' that was explicitly marked illegal
<unknown>:0: note: see current operation:
"vhlo.func_v1"() <{arg_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{}>]>, function_type = #vhlo.type_v1<!vhlo.func_v1<(!vhlo.tensor_v1<256x4096x1024x!vhlo.f8E4M3FN_v1>) -> !vhlo.tensor_v1<256x4096x1024x!vhlo.f8E8M0FNU_v1>>>, res_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{#vhlo.string_v1<"jax.result_info"> = #vhlo.string_v1<"">}>]>, sym_name = #vhlo.string_v1<"main">, sym_visibility = #vhlo.string_v1<"public">}> ({
^bb0(%arg0: !vhlo.tensor_v1<256x4096x1024x!vhlo.f8E4M3FN_v1>):
  %0 = "vhlo.bitcast_convert_v1"(%arg0) : (!vhlo.tensor_v1<256x4096x1024x!vhlo.f8E4M3FN_v1>) -> !vhlo.tensor_v1<256x4096x1024x!vhlo.f8E8M0FNU_v1>
  "vhlo.return_v1"(%0) : (!vhlo.tensor_v1<256x4096x1024x!vhlo.f8E8M0FNU_v1>) -> ()
}) : () -> ()