Closed gillverd closed 4 years ago
If we adopt JAX, I believe we should place it in contrib.
I agree about placing it in contrib.
Does computing the gradients come up a lot in NISQ machine learning stuff? It seems like an odd operation.
It comes up extremely often, in VQE, QAOA, and QNNs alike. I would say the computation of the gradient is as crucial as the execution of the circuit itself, and is often (if not always) the bottleneck when training variational algorithms (the main class of NISQ algorithms). Although gradient-free optimizers exist, there are many reasons to stick to gradient-based optimization, chiefly that it is the most compatible with hybridization with classical neural networks.
@QuantumVerd Since we are not experienced with JAX, we will assign this to you.
So what you probably want is the ability to compute the derivative of a circuit, gate, or operation with respect to a parameter of the gate?
Something like
jax.grad(cirq.XPowGate(exponent))
or maybe
jax.grad(cirq.unitary(cirq.XPowGate(exponent)))
?
Interestingly, the latter makes me wonder if cirq.unitary for a parameterized unitary should return a function from the parameters to the unitary. In that case it would make sense for the returned object to be a jax numpy array, I think?
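To make the "function from the parameters to the unitary" idea concrete, here is a minimal sketch in plain JAX. It does not use Cirq at all: `xpow_unitary` and `flip_probability` are hypothetical helpers written for illustration, and the matrix is the standard X**t unitary with zero global shift, typed out by hand rather than obtained from `cirq.unitary`. Because `jax.grad` needs a real scalar output, the sketch differentiates the probability of flipping |0> to |1> rather than the unitary itself.

```python
import jax
import jax.numpy as jnp


def xpow_unitary(exponent):
    """Hypothetical helper: X**exponent as a 2x2 JAX array (zero global shift)."""
    phase = jnp.exp(1j * jnp.pi * exponent)
    return 0.5 * jnp.array([[1 + phase, 1 - phase],
                            [1 - phase, 1 + phase]])


def flip_probability(exponent):
    """Probability of measuring |1> after applying X**exponent to |0>."""
    amplitude = xpow_unitary(exponent)[1, 0]
    # |amplitude|^2, written without abs() so the derivative is well defined at 0.
    return jnp.real(amplitude * jnp.conj(amplitude))


# Derivative of the flip probability with respect to the gate exponent.
grad_flip = jax.grad(flip_probability)

# Analytically the probability is sin^2(pi * t / 2), so the gradient is (pi / 2) * sin(pi * t).
print(grad_flip(0.5))
```

Whether `cirq.unitary` could return something traceable like this is exactly the open question in the comment above; the sketch only shows what the JAX side would need.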
I don't think we have any plans to support JAX, and I don't currently see a way to do this without major overhead. Someone could try getting it working to see how much of our existing numpy code works, but for now support for Cirq symbolic differentiation resides more in TensorFlow Quantum.
Going to close.
JAX is a Google-developed open source library that can take NumPy code and let one rapidly obtain gradients via automatic differentiation. It also has some nifty features allowing for execution on GPUs and TPUs.
It seems from this tutorial that switching from NumPy to a JAX backend can in some cases take only very minimal changes. It would be a very valuable feature to have in Cirq, as it could greatly accelerate the calculation of circuit gradients and allow users to run circuits on GPUs and other hardware supported by Google's XLA.
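As a rough illustration of the "minimal changes" point, the sketch below writes a tiny one-qubit state-vector simulation in NumPy style, but against `jax.numpy`, and then asks JAX for a compiled gradient. Nothing here is Cirq code: `ry`, `rx`, and `expectation_z` are hypothetical names chosen for the example, and the circuit (an Ry followed by an Rx, measured in Z) is an assumption made only to have something to differentiate.

```python
import jax
import jax.numpy as jnp


def ry(theta):
    """Rotation about the Y axis as a 2x2 complex matrix."""
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]], dtype=jnp.complex64)


def rx(theta):
    """Rotation about the X axis as a 2x2 complex matrix."""
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -1j * s], [-1j * s, c]], dtype=jnp.complex64)


Z = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)


def expectation_z(params):
    """<Z> after applying Ry(params[0]) then Rx(params[1]) to |0>."""
    state = jnp.array([1, 0], dtype=jnp.complex64)
    state = rx(params[1]) @ (ry(params[0]) @ state)
    return jnp.real(jnp.vdot(state, Z @ state))


# jax.grad differentiates the simulation; jax.jit compiles it through XLA,
# so the same code can run on CPU, GPU, or TPU depending on the installed backend.
grad_fn = jax.jit(jax.grad(expectation_z))

params = jnp.array([0.3, 0.7])
grads = grad_fn(params)
print(grads)
```

For this circuit the expectation value is cos(a) * cos(b), so the result can be checked by hand against (-sin(a) * cos(b), -cos(a) * sin(b)). How much of Cirq's existing NumPy code would survive the same swap is a separate question that this sketch does not answer.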