tensorflow / quantum

An open-source Python framework for hybrid quantum-classical machine learning.
https://www.tensorflow.org/quantum
Apache License 2.0

Higher order gradient of tfq layers #285

Status: Open. refraction-ray opened this issue 4 years ago.

refraction-ray commented 4 years ago

Both first-order gradients of tfq layers and higher-order gradients of non-tfq computations work with GradientTape, as follows:

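```python
import cirq
import sympy as sy
import tensorflow as tf
import tensorflow_quantum as tfq

# Higher-order gradient of a plain TF expression via nested tapes:
# the outer tape t records the inner gradient computation.
x = tf.Variable(initial_value=0.2)
with tf.GradientTape() as t:
    with tf.GradientTape() as t2:
        y = x**3
        g = t2.gradient(y, x)
        print(g)
    g2 = t.gradient(g, x)
print(g2)
##########
# First-order gradient of a TFQ layer: rx(a) on one qubit, measured with Z.
a = sy.Symbol("a")
c = cirq.Circuit()
c.append(cirq.rx(a)(cirq.GridQubit(0,0)))
model = tfq.layers.PQC(c, operators=[cirq.Z(cirq.GridQubit(0,0))])
with tf.GradientTape() as t:
    o = model(tfq.convert_to_tensor([cirq.Circuit()]))[0,0]
    g = t.gradient(o, model.variables)
print(g)
```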
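For reference, the values the two working snippets should print can be derived by hand: d(x³)/dx = 3x² = 0.12 and d²(x³)/dx² = 6x = 1.2 at x = 0.2. The PQC output for rx(a) applied to |0⟩ and measured with Z is cos(a), so its gradient is -sin(a). A dependency-free sketch that checks these closed forms with central finite differences; the step size `h` and the sample value `a = 0.5` are arbitrary choices for illustration (the real PQC layer initializes its symbol value randomly):

```python
import math

def central_diff(f, x, h=1e-4):
    # Central finite-difference estimate of f'(x).
    return (f(x + h) - f(x - h)) / (2 * h)

def cube(x):
    return x ** 3

x = 0.2
g = central_diff(cube, x)                               # approx 3 * x**2 = 0.12
g2 = central_diff(lambda t: central_diff(cube, t), x)   # approx 6 * x = 1.2

a = 0.5                       # arbitrary symbol value for the rx(a) circuit
dz = central_diff(math.cos, a)  # approx -sin(a), the gradient of <Z> = cos(a)
print(g, g2, dz)
```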

However, the following doesn't work. TensorFlow fails with `InvalidArgumentError: Operation 'cond' has no attr named '_XlaCompile'.` followed by `ValueError: Insufficient elements in branch_graphs[0].outputs. Expected: 6 Actual: 5`:

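```python
import cirq
import sympy as sy
import tensorflow as tf
import tensorflow_quantum as tfq

# Nesting two tapes around a TFQ layer to get a second-order gradient;
# this is the snippet that raises the error quoted above.
a = sy.Symbol("a")
c = cirq.Circuit()
c.append(cirq.rx(a)(cirq.GridQubit(0,0)))
model = tfq.layers.PQC(c, operators=[cirq.Z(cirq.GridQubit(0,0))])
with tf.GradientTape() as t:
    with tf.GradientTape() as t2:
        o = model(tfq.convert_to_tensor([cirq.Circuit()]))[0,0]
        g = t2.gradient(o, model.variables)
        print(g)
    g2 = t.gradient(g, model.variables)
print(g2)
```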

Is my code wrong, or do tfq layers have special issues with higher-order gradients because of the way AD is implemented in these layers?

jaeyoo commented 4 years ago

Please see section 3.4 of https://arxiv.org/pdf/1901.05374.pdf.

Higher-order gradients are not implemented in TFQ yet.
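Until that is implemented, one possible workaround (not a TFQ API) is to build the second derivative from plain expectation values by applying the parameter-shift rule twice. This assumes each symbol enters through a gate of the form exp(-i a P / 2) with P a Pauli operator, which covers `cirq.rx`. In the dependency-free sketch below, `expectation` is a hypothetical stand-in that returns the closed form cos(a) for the single-qubit circuit in this issue; in practice it would be replaced by a TFQ expectation evaluation at the shifted symbol values:

```python
import math

def expectation(a: float) -> float:
    # Hypothetical stand-in: rx(a) on |0>, measured with Z, gives <Z> = cos(a).
    return math.cos(a)

def parameter_shift_grad(f, a, shift=math.pi / 2):
    # First-order parameter-shift rule for a gate exp(-i a P / 2), P a Pauli.
    return (f(a + shift) - f(a - shift)) / 2

def parameter_shift_hessian(f, a, shift=math.pi / 2):
    # Apply the first-order rule to its own output to get d^2 f / da^2.
    return parameter_shift_grad(
        lambda b: parameter_shift_grad(f, b, shift), a, shift
    )

a = 0.3
print(parameter_shift_grad(expectation, a))     # close to -sin(0.3)
print(parameter_shift_hessian(expectation, a))  # close to -cos(0.3)
```

Nesting the rule expands to (f(a + π) - 2 f(a) + f(a - π)) / 4, so each second derivative needs expectations at three distinct parameter values. Mixed second derivatives across several symbols would need simultaneous shifts in two symbols, which this single-parameter sketch does not cover.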

github-actions[bot] commented 4 years ago

This issue has not had any activity in a month. Is it stale?

mhucka commented 1 day ago

@refraction-ray @jaeyoo For purposes of planning work and doing repository housekeeping, could you let us know what the status of this is?