[BUG] Loading hybrid quantum neuronal networks does not work with TensorFlow #4999

Closed HeyPhiS closed 8 months ago

HeyPhiS commented 8 months ago

Expected behavior

When a model is saved and then reloaded with the TensorFlow function tf.keras.models.load_model, the loaded model is expected to contain all layers and their weights, so that the saved model is recreated in full.

Actual behavior

When a hybrid model that includes both classical layers and a qml.qnn.KerasLayer is saved and then reloaded with tf.keras.models.load_model, the network, including the quantum layer, is reloaded without error, but the weights of the quantum layer are not restored. The issue occurs regardless of the order of the quantum and classical layers.

Additional information

This issue was discussed in the PennyLane support thread. That thread already gave rise to another issue, to which this one is possibly related.

Source code

import tensorflow as tf
import pennylane as qml
import numpy as np

dev = qml.device("default.qubit")
n_qubits = 2

def create_model():
    def circuit(inputs, weights):
        qml.AngleEmbedding(inputs, wires=range(n_qubits))
        qml.RX(weights[0], 0)
        qml.RX(weights[1], 1)
        return [qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliZ(1))]

    weight_shapes = {"weights": (n_qubits,)}
    quantum_layer = qml.qnn.KerasLayer(circuit, weight_shapes, output_dim=2)

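    model = tf.keras.models.Sequential()
    # The quantum layer is added first, so it is model.layers[0] in the printout below.
    model.add(quantum_layer)
    model.add(tf.keras.layers.Dense(2, activation=tf.nn.softmax, input_shape=(2,)))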

    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model

model = create_model()

num_points = 5
dummy_input_data = np.random.uniform(0, np.pi, size=(num_points, 2))
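dummy_output_data = np.random.randint(2, size=(num_points, 2))

# Train for one epoch so the layer weights exist, then save the model to the "model" directory.
model.fit(dummy_input_data, dummy_output_data, epochs=1, batch_size=0)
model.save("model")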

loaded_model = tf.keras.models.load_model("model")

print("saved weights:", model.layers[0].weights)
print("Loaded weights:", loaded_model.layers[0].weights)


INFO:tensorflow:Assets written to: model\assets
INFO:tensorflow:Assets written to: model\assets

saved weights: [<tf.Variable 'weights:0' shape=(2,) dtype=float32, numpy=array([ 0.49809802, -0.9938641 ], dtype=float32)>]

Loaded weights: []
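
The symptom in the output above (one weight tensor before saving, an empty list after loading) can also be checked for every layer at once instead of printing a single layer. The following is only a minimal sketch: the helper name layers_with_lost_weights is hypothetical and not part of PennyLane or TensorFlow, and it assumes nothing beyond each model exposing .layers and each layer exposing .weights, as in the script above.

```python
def layers_with_lost_weights(saved_model, loaded_model):
    """Return the indices of layers that hold fewer weight tensors after loading.

    Hypothetical helper: both arguments only need a `.layers` sequence whose
    items expose a `.weights` list, as Keras models and layers do.
    """
    lost = []
    for index, (saved, loaded) in enumerate(
        zip(saved_model.layers, loaded_model.layers)
    ):
        # A layer is flagged when loading produced fewer weight tensors
        # than the layer had when it was saved.
        if len(loaded.weights) < len(saved.weights):
            lost.append(index)
    return lost
```

Called as layers_with_lost_weights(model, loaded_model) after the script above, it would be expected to flag index 0, going by the printed output.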

System information

The issue is also visible in the tracebacks displayed in the discussion referenced above and is therefore also reproducible with the system configuration displayed there.

timmysilv commented 8 months ago

Hi @HeyPhiS, thanks for reporting this! I did some digging on this today, and I believe I've root-caused the issue and proposed a fix. Once it gets merged in, we'll let you know here.

timmysilv commented 8 months ago

The fix has been merged to a feature branch and will be released in PennyLane v0.34, coming next week! Keep an eye out for that release, and definitely let me know if anything else comes up.
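
For anyone landing here later, a quick way to tell whether an environment is at or past the release named above is to compare the installed version against 0.34. This is just a sketch under a few assumptions: the helper names are hypothetical, version strings are assumed to start with "major.minor", and pre-release suffixes such as ".dev0" are ignored, so a development build of 0.34 is treated like the release.

```python
import re
from importlib import metadata

FIXED_IN = (0, 34)  # release named in the comment above


def parse_major_minor(version_string):
    """Extract (major, minor) from strings such as '0.34.0' or '0.34.0.dev0'."""
    match = re.match(r"(\d+)\.(\d+)", version_string)
    if match is None:
        raise ValueError(f"unrecognised version string: {version_string!r}")
    return int(match.group(1)), int(match.group(2))


def includes_loading_fix(version_string):
    """True when the version is at least the release that ships the fix."""
    return parse_major_minor(version_string) >= FIXED_IN


def installed_pennylane_has_fix():
    """Check the installed PennyLane package; returns None if it is not installed."""
    try:
        return includes_loading_fix(metadata.version("PennyLane"))
    except metadata.PackageNotFoundError:
        return None
```

If installed_pennylane_has_fix() returns False, upgrading PennyLane is the first thing to try before re-running the reproduction script.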

HeyPhiS commented 8 months ago

Great! Thank you very much for looking into this so quickly 🥇