eelregit closed this issue 1 year ago.
Thanks for the report. I suspect the issue is that the weak type is stripped from the input within the constant handler. This is one of the hairy corner cases of JAX's approach to type promotion; I've been working on making the package more consistent in this respect, but we're not there yet.
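For readers unfamiliar with weak types, here is a minimal sketch of why stripping one matters. It assumes that "stripping" behaves like replacing a Python scalar with an explicitly typed `float32` constant. The bfloat16 input is chosen only so the difference is visible, and none of this reproduces the constant handler itself.

```python
import jax
import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.bfloat16)

# A Python scalar is weakly typed: it defers to the array's dtype.
weak = x + 1.0
print(weak.dtype)  # bfloat16

# A strongly typed float32 scalar (standing in for a constant whose weak
# type has been stripped) participates fully in promotion.
strong = x + jnp.float32(1.0)
print(strong.dtype)  # float32

# Under jit, the Python scalar is still treated as weakly typed.
jitted = jax.jit(lambda a: a + 1.0)(x)
print(jitted.dtype)  # bfloat16
```

The only difference between the two additions is whether the scalar carries a weak type, and that alone changes the output dtype.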
Thanks for the nice repo and test case!
This seems to have been fixed somewhere along the way. Here's the output with JAX 0.4.11:
```
fz vjp
float32
float32
fy vjp
float32
float32
fz grad
float32
float32
```
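To re-run this kind of check on another release, a minimal sketch follows. It uses a hypothetical function `f`, because the actual `fz`/`fy` test functions live in the reporter's repo and are not reproduced here. It prints the dtypes of the primal output and cotangent from `jax.vjp`, and of the gradient from `jax.grad`. With JAX's default 32-bit settings, each dtype should be `float32`.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the reporter's test functions.
def f(x):
    return jnp.sin(x) * 2.0  # 2.0 is a weakly typed Python scalar

x = jnp.ones(3, dtype=jnp.float32)

# vjp: primal output and the cotangent pulled back to x.
y, f_vjp = jax.vjp(f, x)
(ct,) = f_vjp(jnp.ones_like(y))
print("f vjp", y.dtype, ct.dtype)

# grad of a scalar-valued wrapper around f.
g = jax.grad(lambda v: f(v).sum())(x)
print("f grad", g.dtype)
```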
returns