Open sracaniere opened 2 years ago
Update: if you replace float32 by bfloat16 in the above code, the multiplication by False gives NaN as expected.
Thanks for the report! This is really peculiar... I'm looking into it now.
This only occurs within JIT, so I suspect this is coming from some XLA simplification where it replaces multiplication by zero with zero:
import jax.numpy as jnp
import jax
def f(x, y):  # This is basically what `jnp.multiply` does internally:
    y = jax.lax.convert_element_type(y, float)
    return jax.lax.mul(x, y)
print(jax.make_jaxpr(f)(jnp.nan, False))
# { lambda ; a:f32[] b:bool[]. let
# c:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
# d:f32[] = mul a c
# in (d,) }
print(f(jnp.nan, False))
# nan
print(jax.jit(f)(jnp.nan, False))
# 0.0
However, the thing I don't understand is that if you replace False with 0 or 0.0, the result correctly returns nan.
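A quick way to check this on a given install is to run the same function under jit with each of the three zero-like inputs. This is only a sketch: `zero_like_inputs`, `results` and `buggy` are names made up for illustration, and the outcome for False depends on the XLA version in use, so no output is shown.

```python
import math

import jax
import jax.numpy as jnp


def f(x, y):
    # Same function as above: cast y to float, then multiply.
    y = jax.lax.convert_element_type(y, float)
    return jax.lax.mul(x, y)


# Hypothetical helper data: the three "zero" inputs discussed in this thread.
zero_like_inputs = {"False": False, "0": 0, "0.0": 0.0}

results = {}
for label, zero in zero_like_inputs.items():
    # Each input dtype (bool, int, float) is traced and compiled separately.
    results[label] = float(jax.jit(f)(jnp.nan, zero))
    print(f"jit(f)(nan, {label}) -> {results[label]}")

# IEEE 754 says nan * 0 is nan, so any input listed here lost the nan.
buggy = [label for label, value in results.items() if not math.isnan(value)]
print("inputs that lose the nan:", buggy)
```

If `buggy` comes back empty, the installed version multiplies in all three cases.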
Here's the difference between how XLA treats integer vs boolean zero. In the case of multiplication with a constant boolean, XLA replaces the multiplication by a select statement, which is correct for every finite floating point value, but will return the wrong value for nan and inf.
print(jax.jit(f).lower(jnp.nan, False).compile().as_text())
HloModule jit_f.3, entry_computation_layout={(f32[],pred[])->f32[]}
ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: pred[]) -> f32[] {
  %Arg_1.2 = pred[] parameter(1)
  %Arg_0.1 = f32[] parameter(0)
  %constant.1 = f32[] constant(0)
  ROOT %select = f32[] select(pred[] %Arg_1.2, f32[] %Arg_0.1, f32[] %constant.1), metadata={op_name="jit(f)/jit(main)/mul" source_file="tmp.py" source_line=6}
}
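To see why the select rewrite only goes wrong for non-finite inputs, the two forms can be compared in plain Python, whose floats also follow IEEE 754. This is a model of the semantics, not XLA code; `mul_by_bool` and `select_by_bool` are illustrative names that do not exist in JAX or XLA.

```python
import math


def mul_by_bool(x: float, pred: bool) -> float:
    # What the jaxpr asks for: convert the boolean to 0.0 or 1.0, then multiply.
    return x * float(pred)


def select_by_bool(x: float, pred: bool) -> float:
    # What the compiled HLO does: pick x when pred is true, else the constant 0.
    return x if pred else 0.0


# Print both results side by side for finite and non-finite inputs.
for x in (2.5, -1.0, math.inf, math.nan):
    for pred in (True, False):
        print(x, pred, mul_by_bool(x, pred), select_by_bool(x, pred))
```

The two functions agree for finite x (up to the sign of zero) and whenever pred is True. They differ for x = nan or inf with pred False, where the multiplication yields nan and the select yields 0.0.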
For integer or float input, the compiled code actually uses multiplication, and so it returns the expected result:
print(jax.jit(f).lower(jnp.nan, 0).compile().as_text())
HloModule jit_f.4, entry_computation_layout={(f32[],s32[])->f32[]}
%fused_computation (param_0: f32[], param_1.1: s32[]) -> f32[] {
  %param_0 = f32[] parameter(0)
  %param_1.1 = s32[] parameter(1)
  %convert.0 = f32[] convert(s32[] %param_1.1), metadata={op_name="jit(f)/jit(main)/convert_element_type[new_dtype=float32 weak_type=False]" source_file="tmp.py" source_line=5}
  ROOT %multiply.0 = f32[] multiply(f32[] %param_0, f32[] %convert.0), metadata={op_name="jit(f)/jit(main)/mul" source_file="tmp.py" source_line=6}
}
ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: s32[]) -> f32[] {
  %Arg_0.1 = f32[] parameter(0)
  %Arg_1.2 = s32[] parameter(1)
  ROOT %fusion = f32[] fusion(f32[] %Arg_0.1, s32[] %Arg_1.2), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(f)/jit(main)/mul" source_file="tmp.py" source_line=6}
}
Internal tracking: b/245348010
Description
Multiplying a NaN by False in JAX gives 0. See example below:
Output:
What jax/jaxlib version are you using?
No response
Which accelerator(s) are you using?
CPU
Additional System Info
Linux