Open wenscarl opened 5 days ago
This PR adds support for the E8M0fnu type. E8M0fnu is an OpenCompute MX scale format with the following properties:
- Unsigned format
- 8 exponent bits
- Exponent range from -127 to 127
- No zero and no infinity
- Single NaN value (0xFF)
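To make the properties above concrete, here is a minimal pure-Python sketch of how an E8M0fnu byte maps to a value. It assumes the exponent bias of 127 implied by the stated range (-127 to 127 across encodings 0x00 to 0xFE). The helper names `decode_e8m0fnu` and `encode_e8m0fnu` are hypothetical and are not part of JAX, ml_dtypes, or this PR.

```python
import math

E8M0_BIAS = 127   # assumed bias, implied by the exponent range -127..127
E8M0_NAN = 0xFF   # the single NaN encoding


def decode_e8m0fnu(byte: int) -> float:
    """Decode one E8M0fnu byte into a Python float (hypothetical helper)."""
    if not 0 <= byte <= 0xFF:
        raise ValueError("E8M0fnu values are a single byte")
    if byte == E8M0_NAN:
        return math.nan
    # No sign bit and no mantissa bits: the value is exactly 2**(byte - bias).
    return math.ldexp(1.0, byte - E8M0_BIAS)


def encode_e8m0fnu(value: float) -> int:
    """Encode an exact power of two in range as an E8M0fnu byte (hypothetical helper)."""
    if math.isnan(value):
        return E8M0_NAN
    mantissa, exponent = math.frexp(value)  # value == mantissa * 2**exponent
    # Zero, negatives, infinities, and non-powers-of-two have no exact encoding.
    if value <= 0 or mantissa != 0.5:
        raise ValueError("not exactly representable in E8M0fnu")
    unbiased = exponent - 1
    if not -127 <= unbiased <= 127:
        raise ValueError("exponent out of range for E8M0fnu")
    return unbiased + E8M0_BIAS
```

Because the format has no zero, an all-zero byte decodes to the smallest value, 2**-127, rather than 0. This sketch rejects inexact inputs instead of rounding them; a real conversion would need a rounding policy, which is not specified here.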
@jakevdp
Smoke test
    import jax
    import jax.numpy as jnp

    def foo(a):
        # Reinterpret the 8-bit payload as E8M0fnu without changing the bits.
        return jax.lax.bitcast_convert_type(a, new_dtype=jnp.float8_e8m0fnu)

    a = jnp.ones((2, 2), dtype=jnp.float8_e4m3fn)
    foo_jit = jax.jit(foo)

    # StableHLO
    print(foo_jit.lower(a).as_text("stablehlo"))
    # HLO
    print(foo_jit.lower(a).as_text("hlo"))
    # Compiled HLO
    print(foo_jit.lower(a).compile().as_text())
Seeing error
    jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Failed to serialize StableHLO; Detailed error from MLIR:
    <unknown>:0: error: failed to legalize operation 'vhlo.func_v1' that was explicitly marked illegal
    <unknown>:0: note: see current operation: "vhlo.func_v1"() <{arg_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{}>]>, function_type = #vhlo.type_v1<!vhlo.func_v1<(!vhlo.tensor_v1<256x4096x1024x!vhlo.f8E4M3FN_v1>) -> !vhlo.tensor_v1<256x4096x1024x!vhlo.f8E8M0FNU_v1>>>, res_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{#vhlo.string_v1<"jax.result_info"> = #vhlo.string_v1<"">}>]>, sym_name = #vhlo.string_v1<"main">, sym_visibility = #vhlo.string_v1<"public">}> ({
    ^bb0(%arg0: !vhlo.tensor_v1<256x4096x1024x!vhlo.f8E4M3FN_v1>):
      %0 = "vhlo.bitcast_convert_v1"(%arg0) : (!vhlo.tensor_v1<256x4096x1024x!vhlo.f8E4M3FN_v1>) -> !vhlo.tensor_v1<256x4096x1024x!vhlo.f8E8M0FNU_v1>
      "vhlo.return_v1"(%0) : (!vhlo.tensor_v1<256x4096x1024x!vhlo.f8E8M0FNU_v1>) -> ()
    }) : () -> ()