jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30k stars, 2.75k forks

vjp + property + disable_jit results in wrong dtype #9698

Closed: eelregit closed this issue 1 year ago

eelregit commented 2 years ago
from typing import NamedTuple

import jax
jax.config.update('jax_enable_x64', True)
import jax.numpy as jnp

class XYZ(NamedTuple):
    x: float
    y: float

    @property
    def z(self):
        return self.x + self.y

def fz(a, xyz):
    return a.sum() * xyz.z

def fy(a, xyz):
    return a.sum() * xyz.y

a = jnp.zeros(5, dtype='f4')
xyz = XYZ(0., 1.)

# with a property, the output dtype depends on whether jit is disabled

print('fz vjp')
print(jax.vjp(fz, a, xyz)[0].dtype)

with jax.disable_jit():
    print(jax.vjp(fz, a, xyz)[0].dtype)

# no problem when reading a plain NamedTuple field

print('\nfy vjp')
print(jax.vjp(fy, a, xyz)[0].dtype)

with jax.disable_jit():
    print(jax.vjp(fy, a, xyz)[0].dtype)

# no problem when using grad (the label says fz, so call fz)

print('\nfz grad')
print(jax.grad(fz)(a, xyz).dtype)

with jax.disable_jit():
    print(jax.grad(fz)(a, xyz).dtype)

which prints:

fz vjp
float32
float64

fy vjp
float32
float32

fz grad
float32
float32
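The reproducer above can be condensed into a regression check that compares the primal-output dtype of `jax.vjp` with and without `jax.disable_jit()`. This is a sketch, not part of the original report: the helper name `vjp_out_dtypes` is hypothetical, and it assumes a recent JAX release in which the bug is fixed, so both dtypes should come out as float32. Note that `jax.vjp` returns `(primals_out, vjp_fn)`, so indexing `[0]` inspects the forward result, not a cotangent.

```python
from typing import NamedTuple

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


class XYZ(NamedTuple):
    x: float
    y: float

    @property
    def z(self):
        # Derived value computed from two Python floats.
        return self.x + self.y


def fz(a, xyz):
    return a.sum() * xyz.z


def vjp_out_dtypes(f, *args):
    """Hypothetical helper: dtype of jax.vjp's primal output, default vs. disable_jit."""
    # jax.vjp returns (primals_out, vjp_fn); [0] is the forward result.
    default = jax.vjp(f, *args)[0].dtype
    with jax.disable_jit():
        no_jit = jax.vjp(f, *args)[0].dtype
    return default, no_jit


a = jnp.zeros(5, dtype=jnp.float32)
xyz = XYZ(0.0, 1.0)
default, no_jit = vjp_out_dtypes(fz, a, xyz)
print(default, no_jit)
```

On an affected version the two values would disagree, as in the output above.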
jakevdp commented 2 years ago

Thanks for the report. I suspect the issue is that the weak type is stripped from the input within the constant handler. This is one of the hairy corner cases of JAX's approach to type promotion; I've been working on making the package more consistent in this respect, but we're not there yet.
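To illustrate the weak-type mechanism mentioned above, here is a minimal sketch that assumes a recent JAX with x64 enabled. A Python float is weakly typed, so it adopts the float32 dtype of the array it is combined with. A strongly typed float64 value instead promotes the product to float64. If the weak type on `xyz.z` were lost, the result would be consistent with the float64 output in the report.

```python
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

a = jnp.zeros(5, dtype=jnp.float32)

# Weakly typed Python scalar: the float32 dtype of a.sum() wins.
weak = a.sum() * 1.0

# Strongly typed float64 scalar: float32 * float64 promotes to float64.
strong = a.sum() * jnp.float64(1.0)

print(weak.dtype, strong.dtype)
```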

Thanks for the nice repro and test case!

jakevdp commented 1 year ago

This seems to have been fixed somewhere along the way. Here's the output with JAX 0.4.11:

fz vjp
float32
float32

fy vjp
float32
float32

fz grad
float32
float32