PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0
2.29k stars 588 forks

[BUG] Loading hybrid quantum neural networks does not work with TensorFlow #4999

Closed: HeyPhiS closed this 8 months ago

HeyPhiS commented 8 months ago

Expected behavior

When using the TensorFlow functions model.save and tf.keras.models.load_model, it is expected that the model loads all layers and their weights, recreating the saved model exactly.

Actual behavior

When using model.save and tf.keras.models.load_model with a hybrid model that contains both classical layers and a qml.qnn.KerasLayer, the network including the quantum layer is reloaded without error; however, the weights of the quantum layer are not. This issue exists independently of the order of the quantum and classical layers.

Additional information

This issue was discussed in the PennyLane support thread https://discuss.pennylane.ai/t/error-reloading-circuit-from-qasm-string/3679. That thread already gave rise to issue https://github.com/PennyLaneAI/pennylane/issues/4856, to which this issue is possibly related.

Source code

import tensorflow as tf
import pennylane as qml
import numpy as np

dev = qml.device("default.qubit")
n_qubits = 2

def create_model():
    @qml.qnode(dev)
    def circuit(inputs, weights):
        qml.AngleEmbedding(inputs, wires=range(n_qubits))
        qml.RX(weights[0], 0)
        qml.RX(weights[1], 1)
        return [qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))]

    weight_shapes = {"weights": (n_qubits,)}
    quantum_layer = qml.qnn.KerasLayer(circuit, weight_shapes, output_dim=2)

    model = tf.keras.models.Sequential()
    model.add(quantum_layer)
    model.add(tf.keras.layers.Dense(2, activation=tf.nn.softmax, input_shape=(2,)))

    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model

model = create_model()

num_points = 5
dummy_input_data = np.random.uniform(0, np.pi, size=(num_points, 2))
dummy_output_data = np.random.randint(2, size=(num_points, 2))

model.fit(dummy_input_data, dummy_output_data, epochs=1, batch_size=5)

model.save("model")

loaded_model = tf.keras.models.load_model("model")

print("")
print("saved weights:", model.layers[0].weights)
print("")
print("Loaded weights:", loaded_model.layers[0].weights)
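Until the serialization issue is resolved, one common workaround (not from this thread; the helper names below are illustrative) is to persist the weight arrays returned by `model.get_weights()` yourself, rebuild the model from code, and re-apply them with `model.set_weights()`. A minimal numpy-only sketch of that round trip:

```python
import numpy as np

# Illustrative workaround: store the list of weight arrays in a single
# .npz file instead of relying on tf.keras.models.load_model to restore
# the KerasLayer variables. Helper names are hypothetical.

def save_weight_list(weights, path):
    # np.savez stores each array under arr_0, arr_1, ... in one .npz file
    np.savez(path, *weights)

def load_weight_list(path):
    with np.load(path) as data:
        # restore the arrays in their original order
        return [data[f"arr_{i}"] for i in range(len(data.files))]

# a weight list shaped like the quantum layer's single (2,) variable
weights = [np.array([0.49809802, -0.9938641], dtype=np.float32)]
save_weight_list(weights, "quantum_weights.npz")
restored = load_weight_list("quantum_weights.npz")
assert all(np.allclose(a, b) for a, b in zip(weights, restored))
```

With a real model, one would call `weights = model.get_weights()` before saving and `loaded_model.set_weights(restored)` after rebuilding the architecture (e.g. via `create_model()`); both are standard Keras methods.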

Tracebacks

WARNING:tensorflow:You are casting an input of type complex128 to an incompatible dtype float32.  This will discard the imaginary part and may not be what you intended.
WARNING:tensorflow:You are casting an input of type complex128 to an incompatible dtype float32.  This will discard the imaginary part and may not be what you intended.
WARNING:tensorflow:You are casting an input of type complex128 to an incompatible dtype float32.  This will discard the imaginary part and may not be what you intended.
WARNING:tensorflow:You are casting an input of type complex128 to an incompatible dtype float32.  This will discard the imaginary part and may not be what you intended.
1/1 [==============================] - 0s 349ms/step - loss: 0.5533 - accuracy: 0.8000
WARNING:tensorflow:AutoGraph could not transform <function validate_device_wires at 0x0000027642280F70> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: 'NoneType' object has no attribute '_fields'
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function validate_device_wires at 0x0000027642280F70> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: 'NoneType' object has no attribute '_fields'
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING:tensorflow:AutoGraph could not transform <function _gcd_import at 0x000002762FE83490> and will run it as-is.
Cause: Unable to locate the source code of <function _gcd_import at 0x000002762FE83490>. Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.experimental.do_not_convert. Original error: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function _gcd_import at 0x000002762FE83490> and will run it as-is.
Cause: Unable to locate the source code of <function _gcd_import at 0x000002762FE83490>. Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.experimental.do_not_convert. Original error: could not get source code
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
INFO:tensorflow:Assets written to: model\assets
INFO:tensorflow:Assets written to: model\assets

saved weights: [<tf.Variable 'weights:0' shape=(2,) dtype=float32, numpy=array([ 0.49809802, -0.9938641 ], dtype=float32)>]

Loaded weights: []

System information

Platform info:           Windows-10-10.0.19045-SP0
Python version:          3.10.12
Numpy version:           1.26.0
Scipy version:           1.11.2
Installed devices:
- default.gaussian (PennyLane-0.33.1)
- default.mixed (PennyLane-0.33.1)
- default.qubit (PennyLane-0.33.1)
- default.qubit.autograd (PennyLane-0.33.1)
- default.qubit.jax (PennyLane-0.33.1)
- default.qubit.legacy (PennyLane-0.33.1)
- default.qubit.tf (PennyLane-0.33.1)
- default.qubit.torch (PennyLane-0.33.1)
- default.qutrit (PennyLane-0.33.1)
- null.qubit (PennyLane-0.33.1)
- lightning.qubit (PennyLane-Lightning-0.33.1)
- strawberryfields.fock (PennyLane-SF-0.29.0)
- strawberryfields.gaussian (PennyLane-SF-0.29.0)
- strawberryfields.gbs (PennyLane-SF-0.29.0)
- strawberryfields.remote (PennyLane-SF-0.29.0)
- strawberryfields.tf (PennyLane-SF-0.29.0)
- qiskit.aer (PennyLane-qiskit-0.32.0)
- qiskit.basicaer (PennyLane-qiskit-0.32.0)
- qiskit.ibmq (PennyLane-qiskit-0.32.0)
- qiskit.ibmq.circuit_runner (PennyLane-qiskit-0.32.0)
- qiskit.ibmq.sampler (PennyLane-qiskit-0.32.0)
- qiskit.remote (PennyLane-qiskit-0.32.0)

The issue is also visible in the tracebacks shown in https://github.com/PennyLaneAI/pennylane/issues/4856 and is therefore also reproducible with the system configuration described there.

Existing GitHub issues

timmysilv commented 8 months ago

Hi @HeyPhiS, thanks for reporting this! I did some digging on this today, and I believe I've root-caused the issue and proposed a fix. Once it gets merged in, we'll let you know here.

timmysilv commented 8 months ago

The fix has been merged to a feature branch and will be released in PennyLane v0.34, coming next week! Keep an eye out for that release, and definitely let me know if anything else comes up.
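Once the release is out, upgrading and re-running the reproduction should show matching weights. A hedged sketch (the helper name is illustrative) of a version guard one could use before depending on the fixed behavior:

```python
from importlib.metadata import version, PackageNotFoundError

def has_min_version(package, minimum):
    # True if `package` is installed and its major.minor version
    # is at least the `minimum` tuple, e.g. (0, 34) for the fix release
    try:
        parts = tuple(int(p) for p in version(package).split(".")[:2])
    except PackageNotFoundError:
        return False
    return parts >= minimum

# e.g. has_min_version("pennylane", (0, 34)) should be True after upgrading
```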

HeyPhiS commented 8 months ago

Great! Thank you very much for looking into this so quickly 🥇