FHof / torchquad

Multidimensional numerical integration on the GPU using PyTorch
https://www.esa.int/gsp/ACT/open_source/torchquad/
GNU General Public License v3.0

Support gradient calculation and JIT compilation #17

Closed. FHof closed this pull request 2 years ago.

FHof commented 2 years ago

I replaced len() with .shape[0] and changed the validity check for the integration_domain argument in two ways. First, a tf.Variable can now be passed as integration_domain. Second, the boundary value checks are skipped when the values are not concrete, which can happen while the function is being compiled. With these changes, JIT compilation and gradient calculation work with JAX, Torch and TensorFlow.
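The sketch below shows, in JAX, the idea behind the two changes. Reading the number of dimensions from .shape[0] works on traced arrays because shapes stay static under jax.jit. The bounds check is attempted and then skipped when converting to a Python bool fails because the values are not concrete. It is a minimal illustration, not torchquad's actual code. check_integration_domain, midpoint_integrate_1d and integral_of_square are hypothetical names, and the 1D midpoint rule is a stand-in for torchquad's integrators. The assumption here is that JAX signals non-concrete values with jax.errors.ConcretizationTypeError (or a subclass).

```python
import jax
import jax.numpy as jnp


def check_integration_domain(integration_domain):
    """Hypothetical check: shape via .shape[0], bounds only when concrete."""
    # .shape[0] instead of len(): the shape is static even while tracing
    dim = integration_domain.shape[0]
    if dim < 1 or integration_domain.shape != (dim, 2):
        raise ValueError("integration_domain must have shape (dim, 2) with dim >= 1")
    try:
        # bool() needs concrete values; under jax.jit the comparison is a tracer
        if bool(jnp.any(integration_domain[:, 0] > integration_domain[:, 1])):
            raise ValueError("lower bounds must not exceed upper bounds")
    except jax.errors.ConcretizationTypeError:
        # Values are not concrete (e.g. during JIT compilation): skip the check
        pass
    return dim


def midpoint_integrate_1d(fn, integration_domain, n=100):
    """Hypothetical 1D midpoint rule, differentiable w.r.t. the domain."""
    check_integration_domain(integration_domain)
    a, b = integration_domain[0, 0], integration_domain[0, 1]
    h = (b - a) / n
    # Midpoints of the n subintervals of [a, b]
    x = a + h * (jnp.arange(n) + 0.5)
    return h * jnp.sum(fn(x))


def integral_of_square(domain):
    # Integral of x**2 over the single dimension of domain
    return midpoint_integrate_1d(jnp.square, domain)


domain = jnp.array([[0.0, 1.0]])
value = jax.jit(integral_of_square)(domain)  # approximately 1/3
# Gradient w.r.t. both bounds, compiled; d/db of the integral is about b**2 = 1
grad = jax.jit(jax.grad(integral_of_square))(domain)
```

Called eagerly with a reversed domain such as [[1.0, 0.0]], the check raises a ValueError. Under jax.jit the same call skips the check and simply returns a negative result. That is the trade-off the PR accepts so that the function can be compiled.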