flav-io / flavio

A Python package for flavour physics phenomenology in the Standard Model and beyond
http://flav-io.github.io/
MIT License

Performance and possible speedup #163

Open jonas-eschle opened 3 years ago

jonas-eschle commented 3 years ago

Hi all, we are using flavio to generate and fit heavily. Since we perform many toys, speed becomes a critical issue for us and we started looking into possible ways of speeding flavio up.

Are there any plans or ideas to increase the speed of flavio, or anything performance-related in the code? Either on a software level or on a mathematical level, such as more caching?

When we had a look at common speedups such as JIT compilation with numba or tracing with JAX or TensorFlow, it seemed that the code is not uniformly written: some parts use math, others numpy. There is also a lot of Python boilerplate (such as a dict for a complex number) and constructs around that make it difficult to obtain a simple speedup.

Are there any plans/ideas in this direction?

DavidMStraub commented 3 years ago

such as a dict for a complex number

:scream: Where?

In general, you are right that the code was not written from the start with optimization in mind (in fact, it was initially written without too much in mind at all).

Then again, we spent some effort in the past getting rid of bottlenecks by caching, vectorization, and other tricks. RG evolution is done fully in C. What is your bottleneck? In my applications, it was always numerical integrations (e.g. in q² bins) of functions that contained lots of hard-to-speed-up stuff like interpolating functions. I wouldn't know how JAX et al. would help there.

jonas-eschle commented 3 years ago

Where?

Admittedly, I don't remember where; it was a bottleneck for JAX and I will try to find it again, but the fact that I can't point to it right now speaks for itself. So it was maybe an unlucky encounter rather than the norm.

Yes indeed, there is already quite some speedup machinery in place, and that's what makes further gains non-trivial. Good to hear that this has already been looked into deliberately and that the bottleneck is the same one I found, the integration. What JAX may help with is JIT-compiling the functions to remove Python overhead, and possibly providing automatic gradients. In principle there could be a speedup, but how much is indeed unclear. Since the other optimizations are already implemented, this is the only real speedup we could think of in terms of pure coding improvements. But as mentioned, it would probably require a few changes.

Another thing would be a mathematical one: rewriting certain amplitudes, factoring out the q²-dependent part and then having a local cache. @Abhijit_Mathad (he's notified by chat)

In other words, these are the ideas that we had; the question is also how much we could change flavio for speed, and if there are other ideas in the back of your head.

abhijitm08 commented 3 years ago

@DavidMStraub We are talking particularly about FCNC b-hadron decays. Since the slowdown comes from integration, indeed as Jonas mentioned, in cases where the amplitudes are linear functions of the parameters of interest (Wilson coefficients and form factors), one could factor out the phase-space-dependent part and cache the integrals. AFAIK, such caching is not currently in flavio (please correct me if I am wrong). With this, the first evaluation of the observable would be very slow, but it would really speed up the subsequent evaluations at the fitting stage.
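A minimal sketch of this factorization, with made-up stand-in functions (`f` and `g` are hypothetical, not flavio's actual amplitudes): for an amplitude A(q², C) = C·f(q²) + g(q²) that is linear in a Wilson coefficient C, the binned rate ∫|A|² dq² reduces to three C-independent integrals that can be cached after the first evaluation:

```python
import numpy as np

trapz = np.trapezoid if hasattr(np, "trapezoid") else np.trapz  # renamed in numpy 2.0

# Hypothetical stand-ins for the q2-dependent pieces of an amplitude
# A(q2, C) = C * f(q2) + g(q2), linear in the Wilson coefficient C
# (not flavio's actual functions):
def f(q2):
    return np.sqrt(q2) * np.exp(-q2 / 10.0)

def g(q2):
    return 1.0 / (1.0 + q2)

q2 = np.linspace(1.0, 6.0, 2001)  # a [1, 6] GeV^2 bin

# Expensive q2 integrals, computed once and cached:
I_ff = trapz(np.abs(f(q2))**2, q2)
I_fg = trapz(f(q2) * np.conj(g(q2)), q2)
I_gg = trapz(np.abs(g(q2))**2, q2)

def rate_cached(C):
    """Fast: |C|^2 I_ff + 2 Re(C I_fg) + I_gg, no integration."""
    return abs(C)**2 * I_ff + 2 * (C * I_fg).real + I_gg

def rate_direct(C):
    """Slow reference: re-integrates |A|^2 for every new C."""
    return trapz(np.abs(C * f(q2) + g(q2))**2, q2)
```

The cached and direct versions agree for any complex C; with several Wilson coefficients the same trick yields one cached integral per bilinear.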

dlanci commented 3 years ago

such as a dict for a complex number

😱 Where?

@DavidMStraub, I might be wrong, but every time a Wilson coefficient dictionary is instantiated in the evaluation of the LogL in a WC fit, isn't this what happens?

peterstangl commented 3 years ago

@DavidMStraub We are talking particularly about FCNC b-hadron decays. Since the slowdown comes from integration, indeed as Jonas mentioned, in cases where the amplitudes are linear functions of the parameters of interest (Wilson coefficients and form factors), one could factor out the phase-space-dependent part and cache the integrals. AFAIK, such caching is not currently in flavio (please correct me if I am wrong). With this, the first evaluation of the observable would be very slow, but it would really speed up the subsequent evaluations at the fitting stage.

I actually have a preliminary implementation of basically what you describe. I express observables as polynomials that are quadratic in the Wilson coefficients or other parameters that enter the amplitudes linearly. The coefficients of these polynomials have to be computed only once, and this is done using flavio. Such a polynomial can be expressed as a scalar product of a vector containing the polynomial coefficients and another vector containing Wilson-coefficient and parameter bilinears. Computing theory predictions therefore reduces to a simple linear-algebra problem that can be solved very efficiently. This approach even makes it possible to compute a covariance matrix for the polynomial coefficients of all observables in a fit and thus to extend flavio's FastLikelihood by including theory uncertainties that actually depend on the new physics Wilson coefficients (this was used in https://arxiv.org/abs/2103.13370).
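A toy sketch of the scalar-product idea for a single complex Wilson coefficient (all functions and numbers here are invented for illustration; the real implementation covers many coefficients and observables): the polynomial coefficients are fitted once from a handful of slow evaluations, after which each prediction is a dot product:

```python
import numpy as np

# Hypothetical stand-in for an expensive flavio prediction that is
# quadratic in a single complex Wilson coefficient C:
def slow_observable(C):
    return 0.7 - 0.2 * C.real + 0.05 * C.imag + 0.1 * abs(C)**2

def bilinear_vector(C):
    """Monomial basis [1, Re C, Im C, (Re C)^2, (Im C)^2, Re C * Im C]."""
    return np.array([1.0, C.real, C.imag,
                     C.real**2, C.imag**2, C.real * C.imag])

# Fit the six polynomial coefficients ONCE from six slow evaluations:
samples = [0j, 1 + 0j, -1 + 0j, 1j, -1j, 1 + 1j]
V = np.array([bilinear_vector(C) for C in samples])
y = np.array([slow_observable(C) for C in samples])
p = np.linalg.solve(V, y)

def fast_observable(C):
    """Every later prediction is just a dot product."""
    return p @ bilinear_vector(C)
```

`fast_observable` then reproduces `slow_observable` for any C, so a FastLikelihood-style fit only ever pays the cost of the dot product.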

I am planning to add this functionality to flavio at some point. Unfortunately I do not have enough time at the moment to implement it and there are also some other issues in flavio, like bug fixes, that have a higher priority.

DavidMStraub commented 3 years ago

@mayou36 @abhijitm08 just to clarify, I'm not actively contributing to flavio anymore; I just added my 2 cents since I'm the most likely culprit for any poor design decisions in the early development. If I had had automatic differentiation in mind, maybe some things would be a bit different. Then again, one of the main goals of flavio in the early days was to be easily extensible, e.g. to facilitate comparison between different form factor implementations, which might be more difficult when imposing restrictions on the implementation.

Indeed as @peterstangl is describing, pre-computing coefficients is definitely the best way to overcome the integration bottleneck, but it depends on what one wants to fit. You might be interested in having the analytical dependence on Wilson coefficients and lumping hadronic parameters into coefficients & uncertainties. But maybe you are also interested in fitting the hadronic parameters themselves (e.g. FF parameters from semileptonic decays). So I think whatever precomputation/caching strategy you use, it is most likely specific to the application and cannot be done in complete generality.

abhijitm08 commented 3 years ago

@DavidMStraub and @peterstangl : Thank you very much for the replies. We love Flavio and fully understand the motivation for the design choices during development. Agreed that the generalisation is not a straightforward task. For us, the observables that we are interested in to begin with are R(K), R(K*), BF(Bs->mumu) and the angular observables of B->K*mumu, with the WCs being the parameters of interest. For our study, we are conducting fits to quite a lot of toy measurements, O(20k). Each fit takes between 20 and 45 minutes, and adding more observables and parameters will only increase this. So we were brainstorming on the possible speed-ups one could gain in Flavio (JIT compilation, auto-differentiation, caching of integrals, etc.) and thought it was best to discuss with the maintainers the ideas you had moving forward.

@peterstangl : The implementation you talk of would be pretty amazing indeed! What observables do you have this implementation for? Any of the above ones, perhaps? We completely understand your other priorities; however, if you have an example implementation for a certain observable and it is somehow a matter of person power to extend it to the other ones, we could perhaps be of some help to you here (?). At the moment, I do not know how much of a speed gain we would get from JIT compilation (@mayou36 and @dlanci), but we can certainly investigate this with Flavio.

peterstangl commented 3 years ago

@peterstangl : The implementation you talk of would be pretty amazing indeed! What observables do you have this implementation for? Any of the above ones, perhaps? We completely understand your other priorities; however, if you have an example implementation for a certain observable and it is somehow a matter of person power to extend it to the other ones, we could perhaps be of some help to you here (?). At the moment, I do not know how much of a speed gain we would get from JIT compilation (@mayou36 and @dlanci), but we can certainly investigate this with Flavio.

@abhijitm08 I have done the implementation in terms of a separate python package that provides everything needed to construct a likelihood in terms of second order polynomials in the Wilson coefficients. Using this package, I have constructed a likelihood containing all of the observables you mention above. One idea would be to make this package a submodule of flavio at some point. But currently the package is still under development in the context of unpublished work. That's also why I do not want to make it public yet. Anyway, I will think about how I could still help you with your issue and maybe provide you with parts of my implementation.

DavidMStraub commented 3 years ago

The ultimate speed-up would then be making @peterstangl's package compatible with JAX (which I suspect shouldn't be too difficult since it's just polynomials) and then having a likelihood with a gradient for gradient-based optimization or Hamiltonian Monte Carlo :sunglasses:
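Since the polynomial observables are just quadratic forms, the gradient of such a likelihood is itself cheap linear algebra. A toy numpy sketch with invented numbers (autodiff, e.g. `jax.grad`, would produce the same derivatives without writing them out by hand):

```python
import numpy as np

# Toy likelihood: 3 observables, each a quadratic polynomial in one
# complex Wilson coefficient C = r + i*1j (all numbers invented):
rng = np.random.default_rng(0)
P = rng.normal(size=(3, 6))          # polynomial coefficients, one row per observable
obs = np.array([0.5, -0.1, 0.3])     # "measured" values
sigma = np.array([0.1, 0.2, 0.05])   # uncertainties

def v(r, i):
    """Monomial basis [1, r, i, r^2, i^2, r*i]."""
    return np.array([1.0, r, i, r * r, i * i, r * i])

def chi2(r, i):
    res = (P @ v(r, i) - obs) / sigma
    return float(res @ res)

def grad_chi2(r, i):
    """Analytic gradient of chi2; jax.grad would give the same."""
    res = (P @ v(r, i) - obs) / sigma**2
    dv_dr = np.array([0.0, 1.0, 0.0, 2 * r, 0.0, i])
    dv_di = np.array([0.0, 0.0, 1.0, 0.0, 2 * i, r])
    return 2 * np.array([res @ (P @ dv_dr), res @ (P @ dv_di)])

# Cross-check the gradient against central finite differences:
r0, i0, eps = 0.3, -0.4, 1e-6
num = np.array([(chi2(r0 + eps, i0) - chi2(r0 - eps, i0)) / (2 * eps),
                (chi2(r0, i0 + eps) - chi2(r0, i0 - eps)) / (2 * eps)])
```

With an exact gradient this cheap, gradient-based minimizers or HMC samplers become directly usable on the fast likelihood.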

jonas-eschle commented 3 years ago

I've done some more benchmarking and it seems that with a few minor changes, we can gain some speedup with numba (an estimated 10-20% for the tested cases, with some changes to the code). So here I am just looking at the technical speedups in flavio; @peterstangl's improvements using the polynomials are independent and of course something else to examine.

and then having a likelihood with gradient for gradient-based optimization or Hamiltonian Monte Carlo

That would be nice and can work well in general (it is e.g. used in zfit or pyhf to speed up the minimization), but JAX is difficult to use with the JIT: we would also need to adjust the Python logic and would therefore make flavio quite dependent on one package. For example, any if-else logic needs to be written with JAX primitives, while numba can deal with it (but has no analytic gradients). That would ideally be taken into account from the beginning (though maybe it is easier to adjust, and the non-jitted code should work without many modifications). Also, the benefit of autograd is only there if everything is written in JAX; we don't really gain anything from a half-way rewrite (or from the JIT speedup alone).
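To illustrate the if-else point with a toy example (the function is invented; shown with numpy, whose `np.where` mirrors JAX's `jnp.where`): data-dependent Python branches have to be rewritten branch-free before a tracer can compile them:

```python
import numpy as np

# Python-level branching: fine for numba's nopython mode, but a
# JAX/TF tracer cannot follow a data-dependent `if` on array values.
def kinematic_factor_branchy(q2, threshold=4.0):
    if q2 < threshold:          # works only on scalars; raises on arrays
        return np.sqrt(threshold - q2)
    return 0.0

# Branch-free equivalent that tracers (and vectorization) can handle;
# in JAX one would write jnp.where / jnp.maximum instead.
def kinematic_factor_traceable(q2, threshold=4.0):
    # clamp before sqrt so the "dead" branch never produces NaNs
    below = np.sqrt(np.maximum(threshold - q2, 0.0))
    return np.where(q2 < threshold, below, 0.0)
```

Both agree on scalars, but only the branch-free version works on whole q² arrays (and under a JIT tracer).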

So my conclusions on this:

Do you think a few modifications to the code for improved speed using numba (meaning rewriting math-heavy functions such as https://github.com/flav-io/flavio/blob/master/flavio/physics/bdecays/angular.py#L47) would be welcome as PRs? This would not change anything in the build process (unlike cython would).
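As a sketch of the kind of change meant here (a toy stand-in, not flavio's actual angular-coefficient code; the no-op fallback keeps the module importable when numba is absent):

```python
import numpy as np

try:
    from numba import njit  # optional dependency
except ImportError:          # fall back to plain Python if numba is absent
    def njit(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]
        return lambda func: func

@njit(cache=True)
def toy_angular_coefficient(q2, a, b):
    """Math-heavy stand-in for an angular-coefficient function:
    tight scalar loops like this are where nopython mode helps."""
    total = 0.0
    for k in range(1, 200):
        total += (a * np.sin(k * q2) + b * np.cos(k * q2)) / k**2
    return total
```

The decorated function behaves identically with or without numba installed, so such a PR would not add a hard dependency or touch the build process.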

P.S.: @DavidMStraub, it is very understandable that these considerations were not taken into account in the beginning; that is maybe what made flavio what it is now, and that's good, thanks a lot for all this work! Also, the more I inspected the code, the more caching and optimization I found, so the low-hanging fruit is truly gone; I had a somewhat different impression in the beginning, which was an incorrect judgement.