eelregit / pmwd

Differentiable Cosmological Forward Model
BSD 3-Clause "New" or "Revised" License

Jax bug in linear_modes #27

Open maho3 opened 6 months ago

maho3 commented 6 months ago

On the master branch, this line causes the following error:

  File "/home/mattho/git/ltu-cmass/cmass/nbody/pmwd.py", line 74, in run_density
    ic = linear_modes(wn, pmcosmo, pmconf)
ValueError: the `static_argnums` argument to `jax.checkpoint` / `jax.remat` can only take integer values greater than or equal to `-len(args)` and less than `len(args)`, but got (4,)
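A minimal reproduction sketch, assuming the cause is the one the error message describes: `jax.checkpoint` checks `static_argnums` against the number of *positional* arguments actually passed, so a static argument that keeps its default value (or is passed as a keyword) is out of range. The function `scale` is hypothetical and stands in for something like `linear_modes`; it is not pmwd code.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in: the trailing argument has a default value and is
# marked static for jax.checkpoint, because it drives a Python `if`.
@partial(jax.checkpoint, static_argnums=(2,))
def scale(x, factor=2.0, real=False):
    y = x * factor
    return jnp.real(y) if real else y


x = jnp.arange(3.0)

# Relying on the defaults leaves only one positional argument, so
# static_argnums=(2,) is out of range; in the affected JAX versions this
# raises the same ValueError as in the traceback above.
try:
    scale(x)
except ValueError as err:
    print("ValueError:", err)

# Passing every argument positionally keeps len(args) == 3, which works.
print(scale(x, 2.0, False))
```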

It seems @eelregit has already reported this as a JAX bug, but it hasn't been resolved yet.

Could a temporary workaround be pushed to master? This is currently breaking the master branch.

eelregit commented 6 months ago

Hi Matt,

Thanks for pointing out that JAX issue, and sorry for the delay. Have you tried the positional-argument workaround suggested there?
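One way to apply a positional-argument workaround, sketched under the assumption that the static argument has a default value: keep the checkpointed function private and expose a thin wrapper that always forwards every argument positionally. All names here (`_scale_impl`, `_scale_ckpt`, `scale`) are hypothetical; this is not the fix from the sto branch or the JAX issue itself.

```python
import jax
import jax.numpy as jnp


def _scale_impl(x, factor, real):
    # `real` is a Python bool, so it must be static for the `if` below.
    y = x * factor
    return jnp.real(y) if real else y


# Checkpoint the implementation; argument index 2 (`real`) is static.
_scale_ckpt = jax.checkpoint(_scale_impl, static_argnums=(2,))


def scale(x, factor=2.0, real=False):
    # Forward every argument positionally so len(args) == 3 and
    # static_argnums=(2,) stays in range, even when callers rely on the
    # defaults or pass `real` as a keyword.
    return _scale_ckpt(x, factor, real)


x = jnp.arange(3.0)
print(scale(x))             # defaults only
print(scale(x, real=True))  # static argument passed as a keyword
```

The public signature keeps its defaults and keyword arguments, while the checkpointed function only ever sees a fixed number of positional arguments.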

maho3 commented 6 months ago

I ended up implementing the same hack used in your sto branch, and it works for me.

If you want, I can make a PR for this small thing to be put on master.