pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Address google/jax#19885 for numpyro. #1743

Closed: tillahoffmann closed this 7 months ago

tillahoffmann commented 7 months ago

This should address performance issues with LowRankMultivariateNormal distributions that have batch dimensions. Only a single change was required to fix the issue. The codebase contains other [matmul] + [identity] expressions, but, for reasons that are not yet clear, they do not cause any issues. The tests verify that no warning is emitted.
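For context, here is a minimal sketch of the kind of [matmul] + [identity] expression involved. It assumes the affected code resembles the capacitance matrix I + Wᵀ D⁻¹ W that low-rank multivariate normal computations typically factorize, with a batched factor W and diagonal D. The function names are hypothetical, and the diagonal-update variant is only one illustrative way to avoid adding a broadcast identity matrix. It is not necessarily the change made in this pull request, and the details of google/jax#19885 are not reproduced here. The final lines show one common way to assert that no warning is emitted, by turning warnings into errors.

```python
import warnings

import jax
import jax.numpy as jnp


def capacitance_tril_broadcast(W, D):
    """Hypothetical helper: cholesky(I + W^T D^{-1} W) via [matmul] + [identity].

    W has shape (..., n, k) (low-rank factor); D has shape (..., n) (positive diagonal).
    """
    Wt_Dinv = jnp.swapaxes(W, -1, -2) / D[..., None, :]  # (..., k, n)
    K = Wt_Dinv @ W  # batched matmul -> (..., k, k)
    # Broadcast a (k, k) identity across all batch dimensions.
    return jnp.linalg.cholesky(K + jnp.eye(K.shape[-1], dtype=K.dtype))


def capacitance_tril_diag_update(W, D):
    """Hypothetical alternative: same matrix, adding 1 only to the diagonal."""
    Wt_Dinv = jnp.swapaxes(W, -1, -2) / D[..., None, :]
    K = Wt_Dinv @ W
    idx = jnp.arange(K.shape[-1])
    # Indexed update of the diagonal entries K[..., i, i] for every batch element.
    K = K.at[..., idx, idx].add(1.0)
    return jnp.linalg.cholesky(K)


# Usage: a batch of 4 distributions, event size 5, rank 2.
key_w, key_d = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(key_w, (4, 5, 2))
D = 0.5 + jnp.exp(jax.random.normal(key_d, (4, 5)))

L1 = jax.jit(capacitance_tril_broadcast)(W, D)

with warnings.catch_warnings():
    warnings.simplefilter("error")  # any warning raised here becomes an exception
    L2 = jax.jit(capacitance_tril_diag_update)(W, D)

print(L1.shape)  # (4, 2, 2)
```

Both variants compute the same lower-triangular factor; they differ only in how the identity is added, which is the kind of difference that can matter for how XLA handles batched inputs.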