quantumlib / Cirq

A Python framework for creating, editing, and invoking Noisy Intermediate Scale Quantum (NISQ) circuits.
Apache License 2.0

Support for a JAX backend #1628

Closed: gillverd closed this 4 years ago

gillverd commented 5 years ago

JAX is a Google-developed open-source library that can take NumPy code and rapidly obtain gradients via automatic differentiation (autograd). It also has some nifty features, such as execution on GPUs and TPUs.

It seems from this tutorial that switching from NumPy to a JAX backend can, in some cases, require only minimal changes. This would be a very valuable feature for Cirq: it could greatly accelerate gradient computations for circuits and let users run circuits on GPUs and other hardware that supports Google's XLA.
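As a rough illustration of the "minimal changes" idea, here is a sketch that writes a one-qubit expectation value against a NumPy-like module passed in as `xp`, so the same code could in principle run on `jax.numpy`. The function names are hypothetical and not part of Cirq; the circuit is just Rx(θ) applied to |0⟩ followed by measuring Z, whose exact value is cos θ. With plain NumPy, the gradient has to be approximated by finite differences.

```python
import numpy as np


def rx_expectation_z(theta, xp=np):
    """<Z> after applying Rx(theta) to |0>, written against a NumPy-like module `xp`."""
    # Assumption: any module exposing cos, sin, array, conj and real
    # (e.g. jax.numpy) could stand in for NumPy here.
    c = xp.cos(theta / 2)
    s = xp.sin(theta / 2)
    # Rx(theta) = exp(-i * theta * X / 2) = cos(theta/2) I - i sin(theta/2) X
    rx = xp.array([[c, -1j * s], [-1j * s, c]])
    zero = xp.array([1.0, 0.0])
    z = xp.array([[1.0, 0.0], [0.0, -1.0]])
    state = rx @ zero
    return xp.real(xp.conj(state) @ (z @ state))


def finite_difference(f, x, eps=1e-6):
    """Central-difference derivative: the fallback when no autodiff backend is used."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)


theta = 0.3
value = rx_expectation_z(theta)                     # cos(theta), up to rounding
slope = finite_difference(rx_expectation_z, theta)  # approximately -sin(theta)
```

If JAX is installed, `jax.grad(lambda t: rx_expectation_z(t, xp=jax.numpy))(0.3)` should in principle return the exact derivative instead. That is untested here and assumes every array operation used has a `jax.numpy` equivalent, which is the kind of audit Cirq's own NumPy code would need.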

vtomole commented 5 years ago

If we adopt JAX, I believe we should place it in contrib.

Strilanc commented 5 years ago

I agree about placing it in contrib.

Does computing gradients come up a lot in NISQ machine learning stuff? It seems like an odd operation.

gillverd commented 5 years ago

It comes up extremely often, in VQE, QAOA, and QNNs alike. I would say computing the gradient is as crucial as executing the circuit itself, and it is often (if not always) the bottleneck when training variational algorithms, the main class of NISQ algorithms. Although gradient-free optimizers exist, there are many reasons to stick with gradient-based optimization, as it is the most compatible with hybrid models that include classical neural networks.
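For readers wondering why gradients matter so much here, below is a minimal sketch of a gradient-based training loop for a one-parameter circuit. It uses the parameter-shift rule, a standard technique for gates of the form exp(-iθP/2) with P a Pauli operator; that technique is not something proposed in this thread, the circuit is simulated with NumPy, and all names are hypothetical. Each gradient costs two extra circuit evaluations per parameter, which illustrates how gradient evaluation can come to dominate training time as the number of parameters grows.

```python
import numpy as np


def expectation_z(theta):
    """<Z> for the one-parameter circuit Rx(theta)|0>, simulated with NumPy."""
    # Rx(theta) = cos(theta/2) I - i sin(theta/2) X
    rx = np.array([[np.cos(theta / 2), -1j * np.sin(theta / 2)],
                   [-1j * np.sin(theta / 2), np.cos(theta / 2)]])
    state = rx @ np.array([1.0, 0.0])
    z = np.diag([1.0, -1.0])
    return float(np.real(np.conj(state) @ z @ state))


def parameter_shift_gradient(f, theta, shift=np.pi / 2):
    """Exact derivative for gates of the form exp(-i theta P / 2), P a Pauli."""
    # Two extra circuit evaluations per parameter and no step-size error,
    # unlike finite differences.
    return (f(theta + shift) - f(theta - shift)) / 2


def gradient_descent(f, theta, learning_rate=0.5, steps=100):
    """Minimal VQE-style loop: minimize <Z> over the single parameter theta."""
    for _ in range(steps):
        theta -= learning_rate * parameter_shift_gradient(f, theta)
    return theta


grad = parameter_shift_gradient(expectation_z, 0.3)  # -sin(0.3), up to rounding
theta_opt = gradient_descent(expectation_z, 0.3)     # approaches pi, where <Z> = -1
```

On real hardware, each call to `expectation_z` would be an estimate from repeated circuit runs rather than an exact simulation, so the cost of the gradient scales with both the parameter count and the number of shots.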

vtomole commented 5 years ago

@QuantumVerd Since we are not experienced with JAX, we will assign this to you.

dabacon commented 5 years ago

So what you probably want is the ability to compute the derivative of circuits, gates, or operations with respect to the gate's parameters?

Something like

jaxnp.grad(cirq.XPowGate(exponent))

or maybe

jaxnp.grad(cirq.unitary(cirq.XPowGate(exponent)))

?

Interestingly, the latter makes me wonder whether cirq.unitary for a parameterized unitary should return a function from the parameters to the unitary. In that case, I think it would make sense for the returned object to be a JAX NumPy array?
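To make the "function from the parameters to the unitary" idea concrete, here is a minimal NumPy sketch for a single X**t gate. `x_pow_unitary` and its helpers are hypothetical names, not Cirq API; the matrix follows the eigendecomposition of X**t with the default `global_shift` of 0. Because the output is a matrix, a JAX version would need a Jacobian (or a scalar cost built from the matrix) rather than `grad`, which only accepts scalar-valued functions. Note also that `grad` lives in the top-level `jax` module, not in `jax.numpy`.

```python
import numpy as np

I2 = np.eye(2)
X = np.array([[0.0, 1.0], [1.0, 0.0]])


def x_pow_unitary(t):
    """Unitary of X**t (global_shift=0), as a function of the exponent t."""
    # X has eigenvalues +1 and -1, so X**t = P_plus + exp(i*pi*t) * P_minus,
    # with projectors P_plus = (I + X)/2 and P_minus = (I - X)/2.
    phase = np.exp(1j * np.pi * t)
    return (1 + phase) / 2 * I2 + (1 - phase) / 2 * X


def x_pow_unitary_derivative(t):
    """Analytic dU/dt of x_pow_unitary, entry by entry."""
    return 1j * np.pi * np.exp(1j * np.pi * t) / 2 * (I2 - X)


def numerical_derivative(t, eps=1e-6):
    """Central differences, used only to cross-check the analytic formula."""
    return (x_pow_unitary(t + eps) - x_pow_unitary(t - eps)) / (2 * eps)


t = 0.25
analytic = x_pow_unitary_derivative(t)
numeric = numerical_derivative(t)
```

Returning a callable like this keeps the parameter dependence explicit, so any differentiation scheme (finite differences, an analytic formula, or an autodiff library) can be applied to it without changing how the gate itself is defined.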

dabacon commented 4 years ago

I don't think we have any plans to support JAX, and we don't currently see a way to do this without major overhead. Someone could try getting it working to see how much of our current NumPy code works, but for now, support for symbolic differentiation of Cirq circuits resides mainly in TensorFlow Quantum.

Going to close.