dominicpasquali opened this issue 2 years ago
Hi @dominicpasquali, thanks for opening the issue. @josh146 and I discussed it, and I would say that there is no specific need to write a FLAX layer for PennyLane. With the JAX interface, a PennyLane QNode is just another JAX function that you can call anywhere in your code, so it should work seamlessly. The simple example below shows a loss function that passes the output of a neural network defined with FLAX to a PennyLane quantum circuit; you can then differentiate through the loss function as you would for any other FLAX/JAX code. Unless you have a specific use case or example that does not work directly with the JAX interface, I don't think there should be any problem.
```python
import pennylane as qml
import jax
from jax import numpy as jnp
from typing import Sequence
import flax.linen as nn

# A quantum function using PennyLane
N = 2
# Exact (shots=None) simulator; gradients are computed by backpropagation through JAX
dev = qml.device("default.qubit", wires=N, shots=None)

@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(weights: jnp.ndarray) -> jnp.ndarray:
    """A random variational circuit ansatz.

    Args:
        weights (jnp.ndarray): The variational parameters for the ansatz (circuit),
            with shape (n_layers, n_rotations).

    Returns:
        jnp.ndarray: The expectation value of PauliZ on wire 0.
    """
    qml.RandomLayers(weights, wires=range(N))
    return qml.expval(qml.PauliZ(0))

key = jax.random.PRNGKey(42)
weights = jax.random.uniform(key, shape=(N, 3))
print(circuit(weights))

# A neural network using FLAX
class MLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        x = nn.Dense(self.features[-1])(x)
        return x

model = MLP([2, 3])
batch = jnp.ones((2, 3))
variables = model.init(jax.random.PRNGKey(0), batch)
# The output has shape (2, 3), the same shape as `weights` above
output = model.apply(variables, batch)

# PennyLane + FLAX working together, so that you can autodiff through the circuit
def loss(variables):
    output = model.apply(variables, batch)
    out = circuit(output)
    return jnp.sum(out)

print(loss(variables))
print(jax.grad(loss)(variables))
```
Feature details
I would like to have a Flax implementation (layer) for PennyLane.
Implementation
This implementation would be similar to PennyLane's TorchLayer (qml.qnn.TorchLayer).
How important would you say this feature is?
2: Somewhat important. Needed this quarter.
Additional information
The Flax module class (flax.linen.Module) has striking similarities to the PyTorch module class (torch.nn.Module).