NeilGirdhar closed this issue 1 year ago.
We're going to need some sort of repro, I think.
@hawkinsp Okay, I'll invest the time in that. It means starting with my 8k-line program and chopping it down page by page until I have an MWE.
After significantly rewriting my code, I can no longer reproduce this.
I think it may have been caused by a custom JVP rule returning a tangent whose pytree structure differs from that of the primal output. For some reason, Jax isn't catching the structure mismatch, so I had to add my own assertions.
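For anyone hitting something similar, here is a minimal sketch of the kind of assertion described above. The function `scale_pair` and its dict-shaped output are hypothetical and are not taken from the original program. The JVP rule checks, at trace time, that the tangent output has the same pytree structure as the primal output before returning both.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def scale_pair(x):
    # Hypothetical function: returns a pytree (a dict) of two arrays.
    return {"a": 2.0 * x, "b": jnp.sin(x)}


@scale_pair.defjvp
def scale_pair_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    primal_out = scale_pair(x)
    tangent_out = {"a": 2.0 * x_dot, "b": jnp.cos(x) * x_dot}
    # Explicit check: pytree structures are static, so this runs once at
    # trace time and fails loudly if the tangent's container structure
    # (keys, nesting) does not match the primal output's.
    primal_tree = jax.tree_util.tree_structure(primal_out)
    tangent_tree = jax.tree_util.tree_structure(tangent_out)
    assert primal_tree == tangent_tree, (
        f"tangent/primal pytree mismatch: {tangent_tree} vs {primal_tree}"
    )
    return primal_out, tangent_out
```

Because the rule's tangent output is linear in `x_dot`, the same rule also supports reverse mode, so both `jax.jvp` and `jax.grad` go through the assertion.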
Description
I don't have an MWE, but I can give someone access to my repository if it helps. What I do know is that adding a stop-gradient eliminates the crash, which suggests that the cotangent has no shape.
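To illustrate why a stop-gradient can make such a crash disappear, here is a small sketch. The objective `loss` is hypothetical, and the issue does not say where exactly the stop-gradient was inserted. `jax.lax.stop_gradient` makes its argument a constant for differentiation purposes, so no cotangent is propagated into the wrapped subcomputation during the backward pass. A malformed cotangent originating there would therefore never be produced.

```python
import jax
import jax.numpy as jnp


def loss(x):
    # Hypothetical objective: the first term is differentiated normally;
    # the second is wrapped in stop_gradient, so it contributes nothing
    # to the gradient and receives no cotangent in the backward pass.
    return x ** 2 + jax.lax.stop_gradient(jnp.sin(x))


x = jnp.array(1.0)
print(jax.grad(loss)(x))  # only d(x**2)/dx = 2x contributes
```

This is only a workaround for localizing the problem: it changes the gradient being computed, so it is useful for narrowing down which subcomputation produces the bad cotangent rather than as a fix.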
What jax/jaxlib version are you using?
Jax master (0.3.25), jaxlib 0.3.24
Which accelerator(s) are you using?
CPU