LHCfitNikhef / smefit_release

SMEFiT: a Standard Model Effective Field Theory fitter
GNU General Public License v3.0

Upgrade to jax #77

Closed · LucaMantani closed this 5 months ago

LucaMantani commented 6 months ago

This PR upgrades the nested sampling (NS) routine, using JAX to accelerate the sampling and to introduce GPU compatibility.

In addition to the modifications introduced in this PR, the external_chi2 modules need to be updated. The only requirement seems to be that the compute_chi2 function uses jax.numpy.linalg.multi_dot instead of the NumPy version.
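For illustration, a minimal sketch of what this change amounts to, assuming the usual quadratic form for the chi2 (the signature below is illustrative, not the actual external_chi2 interface):

```python
import jax
import jax.numpy as jnp

# double precision, matching the NumPy-based fits (illustrative default)
jax.config.update("jax_enable_x64", True)


def compute_chi2(diff, invcov):
    """Quadratic form diff^T . invcov . diff, written with jax.numpy.

    diff   : residuals (theory - data), shape (ndat,)
    invcov : inverse covariance matrix, shape (ndat, ndat)
    """
    # jnp.linalg.multi_dot replaces numpy.linalg.multi_dot, so the function
    # can be jit-compiled and evaluated on GPU like the rest of the likelihood
    return jnp.linalg.multi_dot([diff, invcov, diff])
```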

I am also introducing the possibility to use vectorized=True in ultranest fits.
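Roughly, the vectorized mode amounts to passing a batched log-likelihood to ultranest. A self-contained sketch with toy ingredients (the data, theory response, and priors below are placeholders, not the smefit runcard machinery):

```python
import jax
import jax.numpy as jnp
import numpy as np
from ultranest import ReactiveNestedSampler

# --- toy stand-ins for the actual smefit ingredients (illustrative only) ---
param_names = ["O81qq", "O83qq", "O8dt"]
ndat = 10
invcov = jnp.eye(ndat)                       # inverse covariance of the data
theory = jnp.ones((ndat, len(param_names)))  # linear EFT response, shape (ndat, nparams)
data = jnp.zeros(ndat)


def chi2(coeffs):
    diff = theory @ coeffs - data
    return jnp.linalg.multi_dot([diff, invcov, diff])


# jit + vmap: evaluate the log-likelihood for a whole batch of points at once
loglike_batch = jax.jit(jax.vmap(lambda c: -0.5 * chi2(c)))


def loglike(points):
    # with vectorized=True ultranest passes a (npoints, nparams) array
    # and expects back a (npoints,) array of log-likelihood values
    return np.asarray(loglike_batch(jnp.asarray(points)))


def prior_transform(cube):
    # flat priors on [-1, 1], applied row-wise to the (npoints, nparams) unit cube
    return 2.0 * cube - 1.0


sampler = ReactiveNestedSampler(param_names, loglike, transform=prior_transform, vectorized=True)
result = sampler.run()
```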

TODO:

LucaMantani commented 6 months ago

I ran a test on a single CPU against main, obtaining compatible results but much faster: a factor of 10 improvement!

Main:

[23:06:41] Time : 51.227 minutes                                                               ultranest.py:286
           Number of samples: 21974                                                            ultranest.py:287
┏━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓
┃ Parameter ┃ Best value ┃ Error ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩
│ O81qq     │ -0.101     │ 0.122 │
│ O83qq     │ 0.020      │ 0.149 │
│ O8dt      │ -0.145     │ 0.233 │
│ O8qd      │ -0.296     │ 0.368 │
│ O8qt      │ -0.207     │ 0.176 │
│ O8qu      │ -0.177     │ 0.246 │
│ O8ut      │ -0.078     │ 0.166 │
│ OWWW      │ -0.000     │ 0.003 │
│ OpD       │ -0.001     │ 0.018 │
│ OpG       │ -0.005     │ 0.005 │
│ OpWB      │ 0.004      │ 0.009 │
│ Ope       │ 0.001      │ 0.004 │
└───────────┴────────────┴───────┘

PR:

[22:13:57] Time : 5.327 minutes                                                                ultranest.py:306
           Number of samples: 21767                                                            ultranest.py:307
┏━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓
┃ Parameter ┃ Best value ┃ Error ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩
│ O81qq     │ -0.099     │ 0.122 │
│ O83qq     │ 0.020      │ 0.151 │
│ O8dt      │ -0.145     │ 0.233 │
│ O8qd      │ -0.294     │ 0.371 │
│ O8qt      │ -0.205     │ 0.175 │
│ O8qu      │ -0.172     │ 0.246 │
│ O8ut      │ -0.076     │ 0.164 │
│ OWWW      │ -0.000     │ 0.003 │
│ OpD       │ -0.000     │ 0.019 │
│ OpG       │ -0.005     │ 0.005 │
│ OpWB      │ 0.004      │ 0.009 │
│ Ope       │ 0.001      │ 0.004 │
└───────────┴────────────┴───────┘

The runcard used for the test is attached: test_jax.txt

juanrojochacon commented 6 months ago

Very impressive @LucaMantani! This certainly seems like a feature we want to exploit in future fits, especially useful if we extend the operator basis.

LucaMantani commented 6 months ago

Constraints are now also implemented, following more or less what is done in the main code. The benchmark is successful.

report_test.pdf
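For readers of the thread, a generic sketch of how linear constraints among Wilson coefficients can be kept JAX-friendly (this is only an illustration of the idea, not the implementation in this PR; the coefficients and relations are made up):

```python
import jax
import jax.numpy as jnp

# Hypothetical setup: two free coefficients c0, c1 and a third one fixed by a
# linear constraint c2 = 0.5*c0 - c1. Encoding the relations in a constant
# (n_full, n_free) matrix keeps the free -> full map a single jit-able product.
constraint_matrix = jnp.array(
    [
        [1.0, 0.0],   # c0 is free
        [0.0, 1.0],   # c1 is free
        [0.5, -1.0],  # c2 = 0.5*c0 - 1.0*c1
    ]
)


@jax.jit
def to_full_coefficients(free_coeffs):
    """Map the free Wilson coefficients to the full coefficient vector."""
    return constraint_matrix @ free_coeffs


# batched version, usable inside a vectorized ultranest likelihood
to_full_batch = jax.vmap(to_full_coefficients)
```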

LucaMantani commented 6 months ago

I performed a benchmark timing the likelihood calls, which are basically the only thing that is truly optimised with JAX.

I used the likelihood of the FCC fit, which I think should be the biggest one we have ever used, with 50 Wilson coefficients, HL_LHC projections, FCC projections, and external likelihoods.

JAX CPU: Time for 50k calls to loglikelihood: 14.532 seconds
JAX GPU: Time for 50k calls to loglikelihood: 10.857 seconds
MAIN CPU: Time for 50k calls to loglikelihood: 96.026 seconds

So, as you can see, there is roughly a factor of 7-10 speed-up. I expected a bit more from the GPU, but maybe there is some further optimisation that can be done; I am not sure.

However, I'd say this is encouraging and should already allow us to scale up the fits and obtain results in more than acceptable times 😊
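For context on how such numbers can be measured fairly, a generic timing sketch for 50k calls of a jitted likelihood (not the actual script used above): JAX dispatches work asynchronously, so the result has to be blocked on before stopping the clock.

```python
import time

import jax
import jax.numpy as jnp

# toy jitted log-likelihood standing in for the FCC one (illustrative only)
ndat, nparams = 100, 50
invcov = jnp.eye(ndat)
theory = jnp.ones((ndat, nparams))
data = jnp.zeros(ndat)


@jax.jit
def loglike(coeffs):
    diff = theory @ coeffs - data
    return -0.5 * jnp.linalg.multi_dot([diff, invcov, diff])


coeffs = jnp.zeros(nparams)
loglike(coeffs).block_until_ready()  # warm-up call: compile outside the timed loop

start = time.perf_counter()
for _ in range(50_000):
    value = loglike(coeffs)
value.block_until_ready()  # wait for the last asynchronous result before stopping the clock
print(f"Time for 50k calls to loglikelihood: {time.perf_counter() - start:.3f} seconds")
```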

juanrojochacon commented 6 months ago

Thanks @LucaMantani, this looks great. I understand that the results are completely identical? Is there any reason not to include this by default, @tgiani @giacomomagni @jacoterh? If so, we should probably make this option the default also when running on CPUs?

jacoterh commented 6 months ago

Very nice Luca! I agree we should definitely make this the default option in the code. I'd say @giacomomagni and I review the various parts of the code, especially the constraints part, but I don't see any major showstoppers!

giacomomagni commented 6 months ago

Thanks for the work @LucaMantani, it really looks like a good improvement!

I've left some comments to see if it is possible to avoid code duplication.