blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0
806 stars, 105 forks

Add some metadata to integrators and export coefficients #679

Closed: reubenharry closed this issue 4 months ago.

reubenharry commented 4 months ago

Current behavior

There is currently no way to obtain the number of gradient calls per step for a given integrator.

It is also sometimes useful to have access to the integrator coefficients, but these aren't exported from integrators.py.

Desired behavior

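Helper functions such as:

# The coefficient lists are assumed to be the module-level ones defined in
# integrators.py (velocity_verlet_coefficients, mclachlan_coefficients,
# yoshida_coefficients).

def calls_per_integrator_step(c):
    """Number of gradient evaluations per step for the integrator with coefficients c."""
    if c == velocity_verlet_coefficients:
        return 1
    if c == mclachlan_coefficients:
        return 2
    if c == yoshida_coefficients:
        return 3
    raise ValueError(f"Unknown integrator coefficients: {c}")

def name_integrator(c):
    """Human-readable name of the integrator with coefficients c."""
    if c == velocity_verlet_coefficients:
        return "velocity_verlet"
    if c == mclachlan_coefficients:
        return "mclachlan"
    if c == yoshida_coefficients:
        return "yoshida"
    raise ValueError(f"Unknown integrator coefficients: {c}")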

Also, export the coefficients from integrators.py.
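The hard-coded lookup above could, in principle, be replaced by deriving the count from the coefficients themselves. A minimal sketch, assuming each coefficient list holds the first half of a palindromic splitting that starts with a momentum update (`[b1, a1, b2, a2, ...]`), and assuming the gradient computed at the end of one step is cached and reused at the start of the next. The helper names `expand_palindrome` and `grad_calls_per_step` are hypothetical and not part of BlackJAX.

```python
from typing import List, Sequence


def expand_palindrome(coefficients: Sequence[float]) -> List[float]:
    """Expand the half list [b1, a1, b2, ...] into the full symmetric sequence.

    The last entry is the middle substep, so it appears only once.
    """
    half = list(coefficients)
    return half + half[-2::-1]


def grad_calls_per_step(coefficients: Sequence[float]) -> int:
    """Count gradient evaluations per integrator step.

    Assumes even positions in the full sequence are momentum updates (each one
    needs a gradient) and that the first of them reuses the gradient cached
    from the end of the previous step.
    """
    full = expand_palindrome(coefficients)
    momentum_updates = len(full[0::2])
    return momentum_updates - 1
```

If the three coefficient lists are stored as half lists of length 2, 3 and 4, this returns 1, 2 and 3, matching the values proposed above. Deriving the number avoids keeping a separate lookup in sync, but it silently gives the wrong answer for a scheme that starts with a position update, so an explicit attribute may still be the safer design.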

junpenglao commented 4 months ago

Per #681, we need another way to expose this information.

How about attaching these properties to the returned integrator object? We could type integrators as a `Protocol`.
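One way to read this suggestion: a `typing.Protocol` describing a callable that also carries metadata attributes. A minimal sketch; the protocol name `Integrator`, the helper `with_metadata`, the attribute names and the toy step function are all hypothetical, and the real BlackJAX integrator signature may differ.

```python
from typing import Any, Callable, Protocol, runtime_checkable


@runtime_checkable
class Integrator(Protocol):
    """A step function that also exposes metadata as attributes."""

    name: str
    grad_calls_per_step: int

    def __call__(self, state: Any, step_size: float) -> Any: ...


def with_metadata(
    step: Callable[..., Any], name: str, grad_calls_per_step: int
) -> Integrator:
    """Attach metadata to a plain Python function (functions accept attributes)."""
    step.name = name  # type: ignore[attr-defined]
    step.grad_calls_per_step = grad_calls_per_step  # type: ignore[attr-defined]
    return step  # type: ignore[return-value]


def _toy_step(state: float, step_size: float) -> float:
    # Placeholder dynamics, only here so the example runs.
    return state + step_size


velocity_verlet = with_metadata(_toy_step, "velocity_verlet", 1)
```

Because the result is still an ordinary function, existing call sites keep working; `isinstance(obj, Integrator)` only checks that the attributes and `__call__` are present, not their types.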

reubenharry commented 4 months ago

Yes, I think that would be good. The info we want attached is:

  1. Order of the integrator
  2. Number of grad calls per step
  3. Name
  4. Type (isokinetic or euclidean)
  5. The integrator itself (this could be a function of "isokinetic" or "euclidean", if you think that's a reasonable design decision)

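The five items above could be grouped into a single immutable record, with item 5 obtained as a function of the kind. A minimal sketch; `IntegratorInfo`, `get_integrator` and the placeholder step are hypothetical, the gradient-call counts come from the issue, and the orders (2, 2, 4) are the standard orders of these splitting schemes, assumed rather than read from BlackJAX.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Literal, Tuple

Kind = Literal["euclidean", "isokinetic"]


@dataclass(frozen=True)
class IntegratorInfo:
    """Metadata bundled with an integrator (one field per item in the list)."""

    order: int                  # 1. order of accuracy
    grad_calls_per_step: int    # 2. gradient evaluations per step
    name: str                   # 3. e.g. "mclachlan"
    kind: Kind                  # 4. "euclidean" or "isokinetic"
    step: Callable[..., Any]    # 5. the integrator itself


def _placeholder_step(state: Any, step_size: float) -> Any:
    # Stand-in for a real integrator step; returns the state unchanged.
    return state


# (order, gradient calls per step) for each scheme.
_SCHEMES: Dict[str, Tuple[int, int]] = {
    "velocity_verlet": (2, 1),
    "mclachlan": (2, 2),
    "yoshida": (4, 3),
}


def get_integrator(name: str, kind: Kind) -> IntegratorInfo:
    """Look up a scheme by name and build its record for the requested kind."""
    if name not in _SCHEMES:
        raise ValueError(f"Unknown integrator: {name!r}")
    if kind not in ("euclidean", "isokinetic"):
        raise ValueError(f"Unknown kind: {kind!r}")
    order, calls = _SCHEMES[name]
    return IntegratorInfo(order, calls, name, kind, _placeholder_step)
```

Keeping the metadata in one frozen record makes it easy to use for benchmarking (for example, normalising cost by `grad_calls_per_step`) without changing how the step function itself is called.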
junpenglao commented 4 months ago

I think the decision ultimately depends on how this information is used in the library beyond benchmarking. Let's circle back when you have some examples.