tensorflow/quantum

Hybrid Quantum-Classical Machine Learning in TensorFlow
https://www.tensorflow.org/quantum
Apache License 2.0

Getting the jacobian inside GradientTape #718

Open lockwo opened 1 year ago

lockwo commented 1 year ago

If I want to compute the jacobian inside of a GradientTape (so I can use it downstream in optimization), classically I can do:

import tensorflow as tf

x = tf.random.normal([7, 1])
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(1,)),
    tf.keras.layers.Dense(5),
])

with tf.GradientTape(persistent=True) as tape:
    y = model(x)
    # Jacobian of the outputs w.r.t. the Dense kernel, taken inside the tape.
    j = tape.jacobian(y, model.trainable_variables[0])

But if I try to do that with a PQC in the model, I get the following error: LookupError: No gradient defined for operation 'TfqAdjointGradient' (op type: TfqAdjointGradient). In general every operation must have an associated `@tf.RegisterGradient` for correct autodiff, which this op is lacking. If you want to pretend this operation is a constant in your program, you may insert `tf.stop_gradient`. This can be useful to silence the error in cases where you know gradients are not needed, e.g. the forward pass of tf.custom_gradient. Please see more details in https://www.tensorflow.org/api_docs/python/tf/custom_gradient.

Example code:


import cirq
import sympy
import tensorflow as tf
import tensorflow_quantum as tfq

x = tfq.convert_to_tensor([cirq.Circuit()] * 7)  # batch of 7 empty input circuits
qubits = cirq.GridQubit.rect(1, 2)
readouts = [cirq.Z(q) for q in qubits]
s = sympy.symbols("a b")
c = cirq.Circuit()
c += cirq.ry(s[0]).on(qubits[0])
c += cirq.ry(s[1]).on(qubits[1])

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(), dtype=tf.string),
    tfq.layers.PQC(c, readouts),
])

with tf.GradientTape(persistent=True) as tape:
    y = model(x)
    # Fails here: LookupError, no gradient registered for TfqAdjointGradient.
    j = tape.jacobian(y, model.trainable_variables[0])

If I put in a tf.stop_gradient(y), then the jacobian just comes back as None. I also don't want to use a stop gradient, because there are downstream tasks I want the gradient to accumulate through for a final GD update. Is this a known limitation, or am I doing something wrong? In either case, what would your recommended path forward be?
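
To make the intent concrete, here is roughly the pattern I'm after, shown with the classical model above (where it already works). The downstream loss here is just a hypothetical placeholder standing in for the real objective:

import tensorflow as tf

x = tf.random.normal([7, 1])
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(1,)),
    tf.keras.layers.Dense(5),
])
opt = tf.keras.optimizers.SGD(learning_rate=0.01)

with tf.GradientTape(persistent=True) as tape:
    y = model(x)
    # Jacobian taken inside the tape so it can feed later computations.
    j = tape.jacobian(y, model.trainable_variables[0])
    # Placeholder downstream objective that consumes the jacobian.
    loss = tf.reduce_sum(tf.square(y)) + tf.reduce_sum(tf.square(j))

# Final GD update through the whole computation.
grads = tape.gradient(loss, model.trainable_variables)
opt.apply_gradients(zip(grads, model.trainable_variables))

This is what I'd like to be able to do with the PQC model in place of the Dense layer.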