bilby-dev / bilby

A unified framework for stochastic sampling packages and gravitational-wave inference in Python.
https://bilby-dev.github.io/bilby/
MIT License

Optimized rescale methods #850

Closed: JasperMartins closed this 1 week ago

JasperMartins commented 3 weeks ago

I have noticed that, for large numbers of samples, PriorDict.rescale becomes quite slow because of the flatten operation, which iterates over every entry with a (slow) native Python for loop.
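To make the cost concrete, here is a minimal sketch of a pure-Python recursive flatten. It is a hypothetical stand-in written for illustration (the name python_flatten is made up, and this is not bilby's actual implementation): every element passes through the interpreter, so the work grows with n_keys * n_samples, whereas np.concatenate performs the same copy in compiled code.

```python
import timeit

import numpy as np


def python_flatten(items):
    """Yield scalars from nested lists/arrays one element at a time.

    Hypothetical stand-in for a pure-Python flatten; not bilby's code.
    """
    for item in items:
        if isinstance(item, np.ndarray) and item.ndim == 0:
            yield item.item()  # 0-d array: unwrap to a Python scalar
        elif isinstance(item, (list, tuple, np.ndarray)):
            yield from python_flatten(item)  # recurse element by element
        else:
            yield item


# Per-key rescale outputs: 3 parameters, n_samples values each
n_samples = 10_000
rng = np.random.default_rng(0)
per_key = [rng.uniform(size=n_samples) for _ in range(3)]

# Every one of the 3 * n_samples values is visited by the Python loop
flat = list(python_flatten(per_key))

# Rough timing comparison; absolute numbers depend on the machine
loop_s = timeit.timeit(lambda: list(python_flatten(per_key)), number=3)
numpy_s = timeit.timeit(lambda: np.concatenate(per_key).tolist(), number=3)
print(f"python loop: {loop_s:.4f} s, np.concatenate: {numpy_s:.4f} s")
```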

This PR provides a relatively simple fix that should handle any output the individual prior rescale methods can reasonably produce. In my testing, the new version is roughly as fast as the old one for a single sample (if anything, slightly faster), and significantly faster for larger sample counts.
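A vectorized replacement for the flatten could look like the following sketch. It is not necessarily what this PR does; it assumes each per-key rescale output is a scalar, a NumPy array of any shape, or an empty list (the placeholder a JointPrior returns before all its parameters are requested), and the helper name vectorized_flatten is hypothetical.

```python
import numpy as np


def vectorized_flatten(samples):
    """Concatenate per-key rescale outputs into one flat float array.

    Hypothetical helper. np.ravel turns scalars, 0-d arrays, n-d arrays
    and empty lists into 1-d arrays, so one compiled concatenate replaces
    the element-by-element Python loop.
    """
    if len(samples) == 0:
        return np.array([])
    return np.concatenate([np.ravel(np.asarray(s, dtype=float)) for s in samples])


# Mixed single-sample outputs, including an empty JointPrior placeholder
single = vectorized_flatten([0.1, np.array(0.2), [], np.array([0.3, 0.4])])
print(single.tolist())  # [0.1, 0.2, 0.3, 0.4]
```

If callers need a plain Python list, .tolist() converts the result. np.ravel uses row-major order, so a 2-d block from a joint prior is emitted parameter by parameter, the same order a recursive Python flatten produces.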

On a related note, wouldn't it make sense for the return value to have the appropriate shape when rescaling more than one sample?
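One way to give multi-sample rescales a sensible shape is to reshape the flat output to (n_keys, n_samples). This is a hedged sketch, assuming every key contributes the same number of samples and the values come back in the requested key order; the helper name rescale_shaped is hypothetical.

```python
import numpy as np


def rescale_shaped(flat, n_keys):
    """Reshape a flat rescale output into (n_keys, n_samples).

    Hypothetical helper: row i holds the samples for key i. A single
    sample is returned as a 1-d array of length n_keys, so existing
    single-sample callers would see the same shape as before.
    """
    flat = np.asarray(flat, dtype=float)
    if flat.size % n_keys != 0:
        raise ValueError("flat output length is not a multiple of the number of keys")
    shaped = flat.reshape(n_keys, -1)
    return shaped[:, 0] if shaped.shape[1] == 1 else shaped


# 3 keys x 2 samples, laid out key by key as a flatten would produce them
multi = rescale_shaped(np.arange(6.0), n_keys=3)
one = rescale_shaped([0.1, 0.2, 0.3], n_keys=3)
```

Keeping the single-sample case 1-d is a design choice aimed at limiting breakage; returning (n_keys, 1) everywhere would be more uniform but would change the shape for every existing caller.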

ColmTalbot commented 2 weeks ago

@mattpitkin IIRC you added the flatten here. I recently ran into it for another reason and would be glad to get rid of it. I agree that it would be nice to return the correct shape. It might break some things, but I think it's worth trying, and if nothing in the test suite breaks, it should be quite painless to get in.

JasperMartins commented 2 weeks ago

I considered a larger-scale change that would remove the need for the flatten entirely. As far as I can tell, the flatten is necessary because JointPriors return empty lists until all of their parameters have been requested. Actually, I wonder: if the requested keys in rescale are not in the right order, i.e. something like keys = (JointPrior:a, SomeOtherPrior, JointPrior:b), the rescaled values would come back in the order (SomeOtherPrior, JointPrior:a, JointPrior:b), right?
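The ordering concern can be reproduced with toy classes. This is a minimal sketch assuming JointPrior behaves as described above (an empty list until all of its parameters have been requested, then all values at once); ToyJointPrior, ToyJointDist and ToyUniform are hypothetical stand-ins, not bilby classes, and the "joint" rescale is just the identity.

```python
import numpy as np


class ToyJointDist:
    """Hypothetical joint distribution that buffers requests."""

    def __init__(self, names):
        self.names = list(names)
        self._pending = {}

    def request(self, name, u):
        self._pending[name] = u
        if len(self._pending) < len(self.names):
            return []  # placeholder until every parameter is requested
        # Identity "rescale", emitted in the distribution's own order
        out = [self._pending[n] for n in self.names]
        self._pending = {}
        return out


class ToyJointPrior:
    """Hypothetical stand-in for a JointPrior on a single parameter."""

    def __init__(self, name, dist):
        self.name, self.dist = name, dist

    def rescale(self, u):
        return self.dist.request(self.name, u)


class ToyUniform:
    """Hypothetical independent prior: uniform on [low, high]."""

    def __init__(self, low, high):
        self.low, self.high = low, high

    def rescale(self, u):
        return self.low + (self.high - self.low) * u


dist = ToyJointDist(["a", "b"])
priors = {"a": ToyJointPrior("a", dist), "other": ToyUniform(10.0, 20.0), "b": ToyJointPrior("b", dist)}
keys = ["a", "other", "b"]
theta = [0.1, 0.5, 0.9]

samples = [priors[k].rescale(u) for k, u in zip(keys, theta)]  # [[], 15.0, [0.1, 0.9]]
flat = np.concatenate([np.ravel(np.asarray(s, dtype=float)) for s in samples]).tolist()
print(flat)  # [15.0, 0.1, 0.9]: values follow (other, a, b), not the requested keys
```

Zipping keys with flat would label 15.0 as "a", which is the mismatch the question describes.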

In any case, one option would be to keep track of the object returned by JointPrior.rescale and update it in place once all keys have been requested. This could be done with either fully initialized NumPy arrays or Python lists.
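A minimal sketch of the in-place idea using preallocated NumPy arrays: the joint distribution hands each parameter a view into a shared buffer and fills the buffer once the last parameter is requested, so the slot returned earlier for "a" is updated after the fact and the output stays in requested-key order. All classes here are hypothetical stand-ins, not bilby's API; the buffer starts as NaN so an incomplete request shows up as NaN rather than as uninitialized memory.

```python
import numpy as np


class InPlaceJointDist:
    """Hypothetical joint distribution that fills its output in place."""

    def __init__(self, names):
        self.names = list(names)
        self._pending = {}
        self._out = None

    def request(self, name, u):
        u = np.atleast_1d(np.asarray(u, dtype=float))
        if self._out is None:
            # One row per parameter, one column per sample; NaN until filled
            self._out = np.full((len(self.names), u.size), np.nan)
        self._pending[name] = u
        slot = self._out[self.names.index(name)]  # basic indexing: a view
        if len(self._pending) == len(self.names):
            # Identity "joint rescale"; writing into the shared buffer also
            # updates the slots handed out for earlier parameters
            for i, n in enumerate(self.names):
                self._out[i] = self._pending[n]
            self._pending = {}
            self._out = None  # earlier views keep the filled buffer alive
        return slot


class ToyJointPrior:
    def __init__(self, name, dist):
        self.name, self.dist = name, dist

    def rescale(self, u):
        return self.dist.request(self.name, u)


class ToyUniform:
    def __init__(self, low, high):
        self.low, self.high = low, high

    def rescale(self, u):
        return self.low + (self.high - self.low) * np.atleast_1d(np.asarray(u, dtype=float))


dist = InPlaceJointDist(["a", "b"])
priors = {"a": ToyJointPrior("a", dist), "other": ToyUniform(10.0, 20.0), "b": ToyJointPrior("b", dist)}
keys = ["a", "other", "b"]
theta = [np.array([0.1, 0.2]), np.array([0.5, 0.5]), np.array([0.9, 0.8])]

# All rescales run before stacking, so the joint slots are filled by then;
# the result is (n_keys, n_samples) in the requested key order, no flatten needed
result = np.stack([priors[k].rescale(u) for k, u in zip(keys, theta)])
```

The caveat of this design is that the result must only be read after every parameter of a joint prior has been requested; reading earlier yields NaN rows.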