pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Fixes `random_flax_module` with `flax.linen.BatchNorm` #1823

Closed. juanitorduz closed this 1 week ago

juanitorduz commented 1 week ago

Fixes https://github.com/pyro-ppl/numpyro/issues/1446

juanitorduz commented 1 week ago

The first two commits are from an old branch 🤦. We can squash and merge instead.

juanitorduz commented 1 week ago

[image]

juanitorduz commented 1 week ago

@fehiepsi, should I leave `_substitute_default_key` in `utils.py`, or is `handlers.py` a better place?

fehiepsi commented 1 week ago

Leaving it in `utils.py` sounds reasonable to me. It is just a workaround for an edge case.

Thanks for fixing the issue!