Closed LucaMantani closed 5 months ago
I ran a test on a single CPU against main, obtaining compatible results but much faster: a factor-10 improvement!
Main:
[23:06:41] Time : 51.227 minutes ultranest.py:286
Number of samples: 21974 ultranest.py:287
┏━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓
┃ Parameter ┃ Best value ┃ Error ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩
│ O81qq │ -0.101 │ 0.122 │
│ O83qq │ 0.020 │ 0.149 │
│ O8dt │ -0.145 │ 0.233 │
│ O8qd │ -0.296 │ 0.368 │
│ O8qt │ -0.207 │ 0.176 │
│ O8qu │ -0.177 │ 0.246 │
│ O8ut │ -0.078 │ 0.166 │
│ OWWW │ -0.000 │ 0.003 │
│ OpD │ -0.001 │ 0.018 │
│ OpG │ -0.005 │ 0.005 │
│ OpWB │ 0.004 │ 0.009 │
│ Ope │ 0.001 │ 0.004 │
└───────────┴────────────┴───────┘
PR:
[22:13:57] Time : 5.327 minutes ultranest.py:306
Number of samples: 21767 ultranest.py:307
┏━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓
┃ Parameter ┃ Best value ┃ Error ┃
┡━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩
│ O81qq │ -0.099 │ 0.122 │
│ O83qq │ 0.020 │ 0.151 │
│ O8dt │ -0.145 │ 0.233 │
│ O8qd │ -0.294 │ 0.371 │
│ O8qt │ -0.205 │ 0.175 │
│ O8qu │ -0.172 │ 0.246 │
│ O8ut │ -0.076 │ 0.164 │
│ OWWW │ -0.000 │ 0.003 │
│ OpD │ -0.000 │ 0.019 │
│ OpG │ -0.005 │ 0.005 │
│ OpWB │ 0.004 │ 0.009 │
│ Ope │ 0.001 │ 0.004 │
└───────────┴────────────┴───────┘
The runcard used for the test is attached: test_jax.txt
Very impressive @LucaMantani! This certainly seems like a feature we want to exploit in future fits, especially useful if we extend the operator basis.
Constraints are now also implemented, following more or less what is done in the main code. The benchmark is successful.
I performed a benchmark in terms of timing the likelihood calls, which is basically the only thing that is truly optimised with JAX.
I used the likelihood of the FCC fit, so I think it should be the biggest one we have ever used, with 50 Wilson coefficients, HL-LHC projections, FCC projections + external likelihoods.
JAX CPU: Time for 50k calls to loglikelihood: 14.532 seconds
JAX GPU: Time for 50k calls to loglikelihood: 10.857 seconds
MAIN CPU: Time for 50k calls to loglikelihood: 96.026 seconds
So, as you can see, there is roughly a factor of ~7-10 speed-up. I expected a bit more from the GPU, but maybe there is some further optimisation that can be done, I'm not sure.
However, I'd say this is encouraging and should already allow us to scale up the fits and obtain results in more than acceptable times 😊
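For reference, a benchmark of this kind can be sketched as below. This is a minimal, hypothetical setup, not the actual fit likelihood: the Gaussian chi2 form, the dimensions, and all variable names are illustrative assumptions. The key points are JIT-compiling the likelihood once before timing and using `block_until_ready()` so that JAX's asynchronous dispatch does not distort the measurement.

```python
# Hypothetical benchmark sketch: time repeated calls to a JIT-compiled
# Gaussian chi2-style log-likelihood. Shapes and data are illustrative.
import time
import jax
import jax.numpy as jnp

n_coeffs = 50  # e.g. the number of Wilson coefficients in the FCC fit

key = jax.random.PRNGKey(0)
data = jax.random.normal(key, (n_coeffs,))
cov_inv = jnp.eye(n_coeffs)  # stand-in inverse covariance

@jax.jit
def loglikelihood(params):
    r = params - data
    return -0.5 * jnp.linalg.multi_dot([r, cov_inv, r])

params = jnp.zeros(n_coeffs)
loglikelihood(params).block_until_ready()  # trigger compilation before timing

start = time.perf_counter()
for _ in range(50_000):
    loglikelihood(params).block_until_ready()
print(f"Time for 50k calls: {time.perf_counter() - start:.3f} seconds")
```

Without the warm-up call, the first invocation would include compilation time; without `block_until_ready()`, the loop would only measure dispatch, not execution.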
Thanks @LucaMantani, looks great. I understand that the results are completely identical? Is there any reason not to include this by default @tgiani @giacomomagni @jacoterh? If not, we should probably make this option the default also when running on CPUs.
Very nice Luca! I agree we should definitely make this the default option in the code. I'd say @giacomomagni and I will review the various parts of the code, especially the part on the constraints, but I don't see any major showstoppers!
Thanks for the work @LucaMantani, it looks like a really good improvement!
I've left some comments to see if it is possible to avoid code duplication.
This PR aims to upgrade the NS routine, using JAX to accelerate sampling and introducing compatibility with GPUs.
In addition to the modifications introduced in the PR, the external_chi2 modules need to be updated. The only requirement seems to be that the function `compute_chi2` uses `jax.numpy.linalg.multi_dot` instead of the NumPy version. I am also introducing the possibility to use `vectorized=True` in ultranest fits.

TODO:
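A minimal sketch of what an updated external_chi2 module might look like, under these two changes. This is not the actual repository code: the function signature, the linear-EFT form of the prediction, and all inputs are illustrative assumptions. The essential points are that the chi2 is built from `jax.numpy.linalg.multi_dot` so it stays JAX-traceable, and that `jax.vmap` supplies the batched interface that `vectorized=True` expects (ultranest then passes a whole batch of points of shape `(n_points, n_params)` and receives one log-likelihood per point).

```python
# Hypothetical sketch of a JAX-compatible chi2 for an external_chi2 module.
import jax
import jax.numpy as jnp

# Illustrative fixed inputs (not the real fit data).
n_dat, n_coeff = 8, 3
theory = jnp.ones((n_dat, n_coeff))  # linearised theory predictions
data = jnp.zeros(n_dat)              # central data values
cov_inv = jnp.eye(n_dat)             # inverse covariance matrix

def compute_chi2(coefficients):
    """chi2 = (T c - d)^T C^{-1} (T c - d), using the JAX multi_dot."""
    diff = theory @ coefficients - data
    return jnp.linalg.multi_dot([diff, cov_inv, diff])

# For vectorized=True, the sampler calls the likelihood on a batch of
# points at once; jax.vmap maps the scalar chi2 over the batch axis.
loglike_batch = jax.jit(jax.vmap(lambda c: -0.5 * compute_chi2(c)))
```

Swapping `numpy.linalg.multi_dot` for the `jax.numpy` version is what allows the whole likelihood to be JIT-compiled and vectorised; otherwise the trace would fall back to host NumPy and lose the speed-up.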