jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Multiplying NaN by False gives 0. instead of NaN #12233

Open sracaniere opened 2 years ago

sracaniere commented 2 years ago

Description

Multiplying a NaN by False in JAX gives 0. See the example below:

import jax.numpy as jnp

nan = jnp.array(float('nan'), dtype=jnp.float32)
print(nan)
print(nan * jnp.array(0., dtype=jnp.float32))
print(nan * jnp.array(False, dtype=jnp.bool_))

Output:

nan
nan
0.0
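
For reference, plain NumPy follows IEEE 754 here, under which NaN times zero is NaN; a minimal check:

import numpy as np

nan = np.float32('nan')
print(nan * np.float32(0.0))  # nan
print(nan * np.bool_(False))  # nan (the bool is promoted to float32 before the multiply)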

What jax/jaxlib version are you using?

No response

Which accelerator(s) are you using?

CPU

Additional System Info

Linux

sracaniere commented 2 years ago

Update: if you replace float32 with bfloat16 in the above code, the multiplication by False gives NaN as expected.
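
A minimal sketch of that variant, assuming the same setup as the original example:

import jax.numpy as jnp

nan = jnp.array(float('nan'), dtype=jnp.bfloat16)
print(nan * jnp.array(False, dtype=jnp.bool_))  # reported to print nan, as expected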

jakevdp commented 2 years ago

Thanks for the report! This is really peculiar... I'm looking into it now.

jakevdp commented 2 years ago

This only occurs within JIT, so I suspect this is coming from an XLA simplification that replaces multiplication by zero with zero:

import jax.numpy as jnp
import jax

def f(x, y):  # This is basically what `jnp.multiply` does internally:
  y = jax.lax.convert_element_type(y, float)
  return jax.lax.mul(x, y)

print(jax.make_jaxpr(f)(jnp.nan, False))
# { lambda ; a:f32[] b:bool[]. let
#     c:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
#     d:f32[] = mul a c
#   in (d,) }

print(f(jnp.nan, False))
# nan

print(jax.jit(f)(jnp.nan, False))
# 0.0

However, the thing I don't understand is that if you replace False with 0 or 0.0, the result correctly returns nan.
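
A quick check of that observation (same f as above, repeated so the snippet is self-contained):

import jax
import jax.numpy as jnp

def f(x, y):
  y = jax.lax.convert_element_type(y, float)
  return jax.lax.mul(x, y)

print(jax.jit(f)(jnp.nan, 0))    # nan -- integer zero keeps the multiply
print(jax.jit(f)(jnp.nan, 0.0))  # nan -- float zero does too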

jakevdp commented 2 years ago

Here's the difference between how XLA treats integer vs boolean zero. In the case of multiplication by a converted boolean, XLA replaces the multiplication with a select between the other operand and a constant zero, which is correct for every finite floating point value but returns the wrong value for nan and inf.

print(jax.jit(f).lower(jnp.nan, False).compile().as_text())
HloModule jit_f.3, entry_computation_layout={(f32[],pred[])->f32[]}

ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: pred[]) -> f32[] {
  %Arg_1.2 = pred[] parameter(1)
  %Arg_0.1 = f32[] parameter(0)
  %constant.1 = f32[] constant(0)
  ROOT %select = f32[] select(pred[] %Arg_1.2, f32[] %Arg_0.1, f32[] %constant.1), metadata={op_name="jit(f)/jit(main)/mul" source_file="tmp.py" source_line=6}
}
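
In other words, the compiled computation behaves like the following sketch, which uses jax.lax.select to mimic the rewrite (the name mul_as_select is just illustrative):

import jax
import jax.numpy as jnp

def mul_as_select(x, b):
  # What the rewritten HLO computes: select(b, x, 0) in place of x * convert(b).
  return jax.lax.select(b, x, jnp.zeros_like(x))

x = jnp.array(float('nan'), dtype=jnp.float32)
print(mul_as_select(x, jnp.array(False)))    # 0.0 -- the select's answer
print(x * jnp.array(0., dtype=jnp.float32))  # nan -- the multiply's answer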

For integer or float input, the compiled code actually uses multiplication, and so it returns the expected result:

print(jax.jit(f).lower(jnp.nan, 0).compile().as_text())
HloModule jit_f.4, entry_computation_layout={(f32[],s32[])->f32[]}

%fused_computation (param_0: f32[], param_1.1: s32[]) -> f32[] {
  %param_0 = f32[] parameter(0)
  %param_1.1 = s32[] parameter(1)
  %convert.0 = f32[] convert(s32[] %param_1.1), metadata={op_name="jit(f)/jit(main)/convert_element_type[new_dtype=float32 weak_type=False]" source_file="tmp.py" source_line=5}
  ROOT %multiply.0 = f32[] multiply(f32[] %param_0, f32[] %convert.0), metadata={op_name="jit(f)/jit(main)/mul" source_file="tmp.py" source_line=6}
}

ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: s32[]) -> f32[] {
  %Arg_0.1 = f32[] parameter(0)
  %Arg_1.2 = s32[] parameter(1)
  ROOT %fusion = f32[] fusion(f32[] %Arg_0.1, s32[] %Arg_1.2), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(f)/jit(main)/mul" source_file="tmp.py" source_line=6}
}
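
A hypothetical workaround suggested by the integer case above (not from this thread, and whether XLA's simplifier rewrites this path too is untested here): route the boolean through int32 so the lowering sees the integer pattern that keeps a real multiply.

import jax
import jax.numpy as jnp

def f_via_int(x, b):
  # Hypothetical: cast bool -> int32 first, hoping to hit the multiply lowering.
  return x * b.astype(jnp.int32)

print(jax.jit(f_via_int)(jnp.nan, jnp.array(False)))  # expected: nan, per the integer case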

jakevdp commented 2 years ago

Internal tracking: b/245348010