zwelitunyiswa closed this issue 8 months ago.
Gosh, you're not going to believe this, but the problem is `jax.random.key(0)` instead of `jax.random.PRNGKey(0)`. The easiest thing for you is to change the type of JAX key you use for now, but it is fixed in `tfp-nightly` and will be in the next stable release (here is the tricky fix from @SiegeLordEx).
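To make the difference between the two key types concrete, here is a minimal sketch. It assumes a recent JAX release (0.4.16 or later, where typed keys and `jax.dtypes.prng_key` exist) and the default threefry PRNG implementation. `jax.random.key(0)` returns a typed key whose dtype is a special key dtype, while `jax.random.PRNGKey(0)` returns a legacy raw `uint32` array. The helpers `is_typed_key` and `as_raw_key` are hypothetical, not part of JAX or TFP; they just show one way to apply the workaround (converting with `jax.random.key_data`) without rewriting every call site.

```python
import jax
import jax.numpy as jnp

# New-style typed key: its dtype is a special PRNG key dtype, not uint32.
typed_key = jax.random.key(0)

# Legacy raw key: a plain uint32 array of shape (2,) under the default threefry impl.
raw_key = jax.random.PRNGKey(0)


def is_typed_key(key) -> bool:
    """Hypothetical helper: True if `key` is a new-style typed PRNG key array."""
    return jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)


def as_raw_key(key):
    """Hypothetical helper: convert a typed key to the legacy uint32 format so it
    can be passed to code that only accepts PRNGKey-style keys. Raw keys are
    returned unchanged."""
    return jax.random.key_data(key) if is_typed_key(key) else key


print(is_typed_key(typed_key))       # True
print(is_typed_key(raw_key))         # False
print(raw_key.dtype, raw_key.shape)  # uint32 (2,)

# For seed 0, the raw bits of the typed key match the legacy key.
print(bool(jnp.array_equal(as_raw_key(typed_key), raw_key)))  # True
```

Going the other way, `jax.random.wrap_key_data` turns raw key bits back into a typed key, so the conversion can be undone once the library accepts typed keys.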
Yes, changing it to `jax.random.PRNGKey(0)` works! Thank you!
When I run the following code, I get this error:
The traceback is as follows: