anvvalade closed this issue 3 days ago.
Interesting. @modichirag told me about the same error message on Slack. Thanks for the example. I wonder if this is another problem that comes and goes with JAX updates. I am busy with job stuff right now and will take a look in 2 weeks at the latest.
Thanks for the quick answer. May your job stuff go smoothly, so I don't spend too long in front of my computer frantically pressing F5! ;)
@eelregit, have you had a chance to look at the issue? I have deadlines coming up and would love to present work based on your code (it'd be a great advertisement for us all ;) )
Ok, I should have taken a closer look 2 weeks ago. Two things: the expected order is `jit(grad())` instead of the other way around. So could you simply try removing the `jit` in your current blackjax pipeline?
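For anyone landing here, a minimal sketch of the two orderings with a toy function (the names are illustrative, nothing here is pmwd's API):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for an expensive log-posterior.
def log_posterior(theta):
    return -0.5 * jnp.sum(theta ** 2)

# Recommended: differentiate first, then compile the gradient.
grad_fn = jax.jit(jax.grad(log_posterior))

# The order that triggered the crash here: compile first, then
# differentiate through the compiled function.
# grad_fn = jax.grad(jax.jit(log_posterior))

print(grad_fn(jnp.ones(3)))  # [-1. -1. -1.]
```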
No problem, you couldn't know!
Using the `lax.scan` function, right? Would that solve the issue?

I was trying to use `tensorflow-probability.mcmc.nuts`, but it seems that using the NUTS of blackjax solves the issue. I'll test it further.
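For reference, the `lax.scan` idea would look roughly like this sketch, with a hypothetical random-walk update standing in for a real transition kernel such as blackjax's NUTS:

```python
import jax
import jax.numpy as jnp

def step(state, key):
    # Hypothetical random-walk update; a real transition kernel
    # (e.g. blackjax NUTS) would go here.
    new_state = state + 0.1 * jax.random.normal(key, state.shape)
    return new_state, new_state  # (carry, sample to record)

keys = jax.random.split(jax.random.PRNGKey(0), 1000)

# scan compiles the whole chain into one XLA loop instead of
# re-dispatching a Python for-loop at every iteration.
final_state, samples = jax.lax.scan(step, jnp.zeros(3), keys)
print(samples.shape)  # (1000, 3)
```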
As to removing the `jit`, it seems like a deal breaker on my end: I need the jitting inside and outside of the posterior pdf, and it takes too long without it. I could try to `jit` parts manually, but that would probably mean tampering with the Monte Carlo libraries if I want the integration to be jitted too... Bad idea!
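A possible middle ground, sketched here with toy names (nothing below is the actual pipeline): jit the posterior and its gradient once, and let the un-jitted sampler loop call the compiled functions.

```python
import jax
import jax.numpy as jnp

def log_posterior(theta):
    # Imagine the expensive forward model (e.g. nbody) in here.
    return -0.5 * jnp.sum(theta ** 2)

# Compile the heavy pieces once; the Python-level driver loop
# stays un-jitted but only ever calls compiled code.
log_post_fn = jax.jit(log_posterior)
grad_fn = jax.jit(jax.grad(log_posterior))

theta = jnp.ones(3)
for _ in range(10):  # un-jitted loop, cheap relative to the model
    theta = theta + 0.01 * grad_fn(theta)  # placeholder update rule
```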
Do I understand correctly that, besides these workarounds, you do not know exactly what causes the bug?
I will probably use `lax.while_loop` in the future; `scan` won't work for some features.
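For context, `lax.while_loop` covers data-dependent termination, which `scan` cannot express because its trip count is fixed at trace time; a minimal sketch:

```python
import jax
import jax.numpy as jnp

# Halve x until it drops below a threshold; the number of
# iterations depends on the data, so scan cannot express it.
def cond(carry):
    i, x = carry
    return x > 1e-3

def body(carry):
    i, x = carry
    return i + 1, 0.5 * x

n_iters, x_final = jax.lax.while_loop(cond, body, (0, jnp.float32(1.0)))
```

One trade-off worth knowing: unlike `scan`, `while_loop` is not reverse-mode differentiable, which matters for gradient-based samplers.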
I don't think the problem is from pmwd, because `int16[64,3]` is a JAX repr. I think they didn't expect the `grad(jit())` order and don't have a test for it. In general I cannot think of a use case like that, so maybe everything is fine now, right?
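A regression test for that ordering could be as small as this sketch (toy function, not pmwd's test suite):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.arange(3.0)
g_usual = jax.jit(jax.grad(f))(x)  # jit(grad()): the common order
g_other = jax.grad(jax.jit(f))(x)  # grad(jit()): the untested order

assert jnp.allclose(g_usual, g_other)
```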
BTW I'm surprised that HMC computations themselves are enough to dominate without jitting. They shouldn't be that heavy, right?
Sorry for the delay -- needed to run some tests before coming back to you.
Turns out you were right (thanks): the issue was not with TensorFlow, but I had a `jit` decorator somewhere that I just needed to remove to avoid the crash. Thanks a lot for that. I am now running into other, funnier crashes.
You are also probably right that the HMC computations themselves are not too heavy, especially when `nbody` is involved. But it'd mean I'd also have to manually `jit` every other function in the code, and there are quite a few!
I'll mark this issue as solved! Thanks.
Thanks for letting me know. Do you still have crashes if you do `jit(grad(posterior(...)))`?
And happy to take a look if you can elaborate on (or share the code for) which `jit` caused the problem and what the funnier crashes are.
BTW, there was an Omega_m gradient problem that came last year with a JAX update and went away with another sometime this year. I never figured out why, but you can use the latest version of JAX to avoid it.
Cannot evaluate `jax.vjp` on the jitted version of `nbody`, getting the error reported below.

Here is a minimal example to reproduce it:
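The original snippet did not survive here, so treat the following as a guess at its shape, adapted from pmwd's README example; the exact `vjp` call the reporter used is an assumption:

```python
import jax
from pmwd import (Configuration, SimpleLCDM, boltzmann,
                  white_noise, linear_modes, lpt, nbody)

# Standard pmwd setup, as in the README.
conf = Configuration(ptcl_spacing=1., ptcl_grid_shape=(64,) * 3)
cosmo = boltzmann(SimpleLCDM(conf), conf)
modes = linear_modes(white_noise(0, conf), cosmo, conf)
ptcl, obsvbl = lpt(modes, cosmo, conf)

# Taking the vjp of the *jitted* nbody is what crashes; jitting
# after differentiating works.
out, vjp_fn = jax.vjp(jax.jit(nbody), ptcl, obsvbl, cosmo, conf)
```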
Running on Python 3.9, with:
Full error output:
Some printing from `Particles.__post_init__` and `tree_util.unflatten`, showing the last loop before crashing and the loop that crashes: