pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Avoid unnecessary reshape for trivial expand #1776

Closed: fehiepsi closed this 5 months ago

fehiepsi commented 5 months ago

This PR avoids unnecessary transpose and reshape operations in the sample method of ExpandedDistribution. The existing tests should cover the change.
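The idea can be illustrated with a simplified NumPy sketch (jax.numpy provides the same `transpose` and `reshape` calls). This is not the NumPyro implementation. The helpers `split_expand` and `sample_expanded` are hypothetical, the base distribution is assumed to be a unit-variance Normal parameterized only by `loc`, the event shape is assumed to be `()`, and "trivial expand" is taken to mean an expand that only adds new leading batch dims. New leading dims can be requested directly as part of the sample shape. Only "interstitial" dims (size 1 in the base batch shape, larger in the target) need the transpose and reshape, so the trivial case can return the draw as-is.

```python
import numpy as np


def split_expand(base_batch_shape, batch_shape):
    """Hypothetical helper: split an expand into new leading dims and
    interstitial dims (size 1 in the base batch shape, > 1 in the target).

    Returns (expanded_sizes, interstitial): expanded_sizes is ordered left to
    right; interstitial maps a negative batch dim to its target size.
    """
    offset = len(batch_shape) - len(base_batch_shape)
    if offset < 0:
        raise ValueError(f"cannot expand {base_batch_shape} to {batch_shape}")
    expanded_sizes = tuple(batch_shape[:offset])
    interstitial = {}
    for i, (old, new) in enumerate(zip(base_batch_shape, batch_shape[offset:])):
        dim = i - len(base_batch_shape)
        if old == 1 and new != 1:
            interstitial[dim] = new
        elif old != new:
            raise ValueError(f"cannot expand {base_batch_shape} to {batch_shape}")
    return expanded_sizes, interstitial


def sample_expanded(rng, loc, batch_shape, sample_shape=()):
    """Hypothetical sampler: draw Normal(loc, 1) samples with batch shape
    `batch_shape`. The shape of `loc` is the base batch shape; the event
    shape is assumed to be ().
    """
    sample_shape, batch_shape = tuple(sample_shape), tuple(batch_shape)
    base_batch_shape = np.shape(loc)
    expanded_sizes, interstitial = split_expand(base_batch_shape, batch_shape)
    # Request extra leading dims from the base sampler: the new batch dims
    # first, then one dim per interstitial size.
    lead = sample_shape + expanded_sizes + tuple(interstitial.values())
    x = rng.standard_normal(lead + base_batch_shape) + loc
    if not interstitial:
        # Trivial expand: x already has shape sample_shape + batch_shape,
        # so no transpose or reshape is needed.
        return x
    # Swap each interstitial sample dim with its matching size-1 batch dim,
    # then reshape to drop the size-1 dims that were moved forward.
    perm = list(range(x.ndim))
    start = len(sample_shape) + len(expanded_sizes)
    for k, dim in enumerate(interstitial):
        p, q = start + k, x.ndim + dim
        perm[p], perm[q] = perm[q], perm[p]
    return np.transpose(x, perm).reshape(sample_shape + batch_shape)
```

For example, expanding a base batch shape `(3,)` to `(2, 3)` takes the early-return path, while expanding `(3, 1)` to `(2, 3, 5)` still needs the permutation.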

This requires reversing the order of expanded_sizes so that the dimensions are ordered from left to right (rather than the current right-to-left order).
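Why the left-to-right order matters can be seen in a small NumPy sketch. It is an illustration, not NumPyro code: `np.zeros` stands in for a base sampler's draw, and the shapes are arbitrary. When the new leading sizes are listed in the same order as they appear in the target batch shape, prepending them to the sample shape already yields the final layout. Listed right to left, the draw comes out with those axes reversed and needs an extra transpose.

```python
import numpy as np

batch_shape = (2, 3, 4)   # target batch shape after expand
base_batch_shape = (4,)   # batch shape of the base distribution
sample_shape = (5,)
n_new = len(batch_shape) - len(base_batch_shape)

# Left -> right: sizes in the order they appear in batch_shape.
ltr = batch_shape[:n_new]      # (2, 3)
# Right -> left: sizes starting from the dim nearest the base dims.
rtl = tuple(reversed(ltr))     # (3, 2)

# Drawing with left -> right sizes gives the final layout directly.
x = np.zeros(sample_shape + ltr + base_batch_shape)
assert x.shape == sample_shape + batch_shape

# Drawing with right -> left sizes gives (5, 3, 2, 4); the expanded axes
# must be reversed with a transpose to reach (5, 2, 3, 4).
y = np.zeros(sample_shape + rtl + base_batch_shape)
lead = len(sample_shape)
perm = (
    list(range(lead))
    + list(range(lead + n_new - 1, lead - 1, -1))
    + list(range(lead + n_new, y.ndim))
)
y = np.transpose(y, perm)
assert y.shape == sample_shape + batch_shape
```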

fehiepsi commented 5 months ago

Thanks, Ola!