N3PDF / vegasflow

VegasFlow: accelerating Monte Carlo simulation across multiple hardware platforms
https://vegasflow.readthedocs.io
Apache License 2.0

Differentiability interface #79

Closed scarlehoff closed 2 years ago

scarlehoff commented 3 years ago

This should address #76

The idea (subject to change as I write the documentation and find out I've done something wrong) is to have a method make_differentiable() that ensures the integration can be run inside a function decorated with tf.function. For now that part is working (and can be differentiated). This PR is missing:

Anyway, here is an example:

from vegasflow import VegasFlow, float_me
import tensorflow as tf

dims = 4
n_calls = int(1e4)
vegas_instance = VegasFlow(dims, n_calls, verbose=False)
z = tf.Variable(float_me(1.0))

def example_integrand(x, **kwargs):
    y = tf.reduce_sum(x, axis=1)
    return y*z

runner = vegas_instance.make_differentiable()
vegas_instance.compile(example_integrand)

# Now we run a few iterations to train the grid; their results can be discarded
_ = vegas_instance.run_integration(3)

@tf.function
def some_complicated_function(x):
    integration_result, error, _ = runner()
    return x*integration_result

my_x = float_me(4.0)
result = some_complicated_function(my_x)

def compute_and_print_gradient():
    with tf.GradientTape() as tape:
        tape.watch(my_x)
        y = some_complicated_function(my_x)

    grad = tape.gradient(y, my_x)
    print(f"Result {y.numpy():.3}, gradient: {grad.numpy():.3}")

compute_and_print_gradient()
z.assign(float_me(4.0))
compute_and_print_gradient()
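
As a sanity check on what the example should print: the integral of z*sum(x) over the 4-dimensional unit hypercube is exactly 2z, so with my_x = 4 the result should be close to 8 (gradient close to 2) for z = 1, and close to 32 (gradient close to 8) for z = 4. The sketch below reproduces those numbers with plain uniform Monte Carlo in pure Python. It does not use VegasFlow or TensorFlow, and the helper names mc_integral and result_and_gradient are made up for illustration.

```python
import random


def mc_integral(z, dims=4, n_calls=10_000, seed=0):
    """Plain Monte Carlo estimate of the integral of z * sum(x) over [0, 1]^dims."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_calls):
        total += z * sum(rng.random() for _ in range(dims))
    return total / n_calls


def result_and_gradient(x, z):
    """Return f(x) = x * integral and its derivative df/dx = integral."""
    integral = mc_integral(z)
    return x * integral, integral


for z in (1.0, 4.0):
    y, grad = result_and_gradient(4.0, z)
    print(f"z = {z}: result {y:.3}, gradient: {grad:.3}")
```

Since the integral does not depend on x, the gradient with respect to x is just the integral itself, which is why changing z rescales both numbers.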

What "make differentiable" actually entails is making sure the integration runs on a single device. I guess handing control of the devices over to TensorFlow would work just as well but, as we tested in the past, TensorFlow's distribution strategies do not work very well for our purposes, which is why we manage the devices ourselves. And if at some point there is a strategy that works, switching would be as easy as emptying the list of devices (because in that case VegasFlow cedes all control to TensorFlow).

There is a second option (to keep the multi-device capabilities working in this situation), which is wrapping the device management in a py_function, but then VegasFlow would be managing a thread pool inside a TensorFlow pool of threads over which it has no control, and it would probably end up summoning creatures from down below. I'm open to suggestions though; maybe there's an easy way of doing this that I'm missing.
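
For reference, this is roughly what the py_function route looks like on the TensorFlow side. It is only a sketch: eager_integration is a hypothetical stand-in for an integration loop that has to run eagerly, not VegasFlow code, and it does no device or thread management at all. It only shows that tf.py_function can sit inside a tf.function and still be differentiated with respect to its inputs.

```python
import tensorflow as tf


def eager_integration(scale):
    # Hypothetical stand-in for an eager-only integration routine:
    # average scale * sum(x) over uniform samples in [0, 1]^4.
    samples = tf.random.uniform((1000, 4))
    return scale * tf.reduce_mean(tf.reduce_sum(samples, axis=1))


@tf.function
def some_complicated_function(x, scale):
    # The Python body runs eagerly even though the caller is a graph function.
    integral = tf.py_function(func=eager_integration, inp=[scale], Tout=tf.float32)
    return x * integral


x = tf.constant(4.0)
scale = tf.constant(1.0)

with tf.GradientTape() as tape:
    tape.watch([x, scale])
    y = some_complicated_function(x, scale)

grad_x, grad_scale = tape.gradient(y, [x, scale])
print(f"Result {y.numpy():.3}, d/dx: {grad_x.numpy():.3}, d/dscale: {grad_scale.numpy():.3}")
```

Note that gradients only flow through tensors passed via inp, which is why scale is an explicit argument here instead of a variable captured from the enclosing scope.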