google / qhbm-library

Quantum Hamiltonian-Based Models built on TensorFlow Quantum
https://qhbm-library.readthedocs.io/en/latest/
Apache License 2.0

Make VQT tf.function traceable #90

Closed: zaqqwerty closed this pull request 3 years ago

zaqqwerty commented 3 years ago

Make VQT tf.function traceable.

Added a test in which VQT is traced, to check that tracing is possible and that the traced function yields the correct values and gradients. It turns out that, when tracing, we need to zero out the second return value of the custom gradient; otherwise the gradients come out at twice their correct value. I didn't use the "run in graph mode" test decorator, since it appears to run the entire test body in graph mode, and the test contains a lot of Python-side logic either way. Wrapping the call in an explicit tf.function seemed like a more direct test of tracing. Ref #78
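The testing approach described above (wrap the function in an explicit `tf.function` and compare against eager execution) can be sketched as follows. This is a minimal illustration, not the library's actual VQT test: `square_with_custom_grad` and `loss_and_grad` are hypothetical stand-ins for a loss with a hand-written gradient, and the sketch does not reproduce the doubled-gradient issue mentioned in the PR. It only shows the general pattern of checking that traced values and gradients match the eager ones.

```python
import tensorflow as tf


@tf.custom_gradient
def square_with_custom_grad(x):
    """Hypothetical toy loss term with a hand-written gradient (not VQT)."""

    def grad(upstream):
        # d(x^2)/dx = 2x, scaled by the upstream gradient.
        return upstream * 2.0 * x

    return x * x, grad


def loss_and_grad(theta):
    """Computes a scalar loss and its gradient with respect to theta."""
    with tf.GradientTape() as tape:
        tape.watch(theta)  # theta is a plain tensor, so watch it explicitly
        loss = tf.reduce_sum(square_with_custom_grad(theta))
    return loss, tape.gradient(loss, theta)


# Explicitly trace the same Python function into a graph.
traced_loss_and_grad = tf.function(loss_and_grad)

theta = tf.constant([0.5, -1.0, 2.0])
eager_loss, eager_grad = loss_and_grad(theta)
traced_loss, traced_grad = traced_loss_and_grad(theta)

# Tracing should change neither the value nor the gradient.
tf.debugging.assert_near(eager_loss, traced_loss)
tf.debugging.assert_near(eager_grad, traced_grad)
```

Comparing against the eager result, rather than only against hard-coded numbers, is what would catch a tracing-specific discrepancy such as gradients coming out at twice their correct value.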

I also bumped the working version number.