PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

Flax Implementation #2642

Open dominicpasquali opened 2 years ago

dominicpasquali commented 2 years ago

Feature details

Would like to have a Flax implementation/layer for PennyLane.

Implementation

This implementation would be similar to the TorchLayer implementation in PennyLane.
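
For reference, this is roughly how the existing TorchLayer is used today (a minimal sketch adapted from the PennyLane docs; the ansatz and weight shapes are just placeholders):

import torch
import pennylane as qml

n_qubits = 2
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev)
def qnode(inputs, weights):
    # Encode classical inputs, then apply a trainable entangling ansatz
    qml.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.BasicEntanglerLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]

# TorchLayer turns the QNode into a torch.nn.Module with trainable weights
weight_shapes = {"weights": (3, n_qubits)}  # (n_layers, n_qubits)
qlayer = qml.qnn.TorchLayer(qnode, weight_shapes)

# It composes with ordinary PyTorch layers
model = torch.nn.Sequential(torch.nn.Linear(2, n_qubits), qlayer)

A Flax equivalent would ideally compose with flax.linen modules in the same way.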

How important would you say this feature is?

2: Somewhat important. Needed this quarter.

Additional information

The Flax module class has striking similarities to the PyTorch module class.
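
To illustrate, a hypothetical Flax layer wrapping a QNode might look something like this (purely a sketch; FlaxQuantumLayer and everything about its interface is made up here, not an existing PennyLane API):

import pennylane as qml
import jax.numpy as jnp
import flax.linen as nn

n_qubits = 2
dev = qml.device("default.qubit.jax", wires=n_qubits)

@qml.qnode(dev, interface="jax")
def qnode(inputs, weights):
    qml.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.BasicEntanglerLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]

class FlaxQuantumLayer(nn.Module):  # hypothetical name, not part of PennyLane
    n_layers: int = 3

    @nn.compact
    def __call__(self, inputs):
        # Register the circuit weights as a trainable Flax parameter
        weights = self.param(
            "weights", nn.initializers.uniform(scale=jnp.pi), (self.n_layers, n_qubits)
        )
        # Assumes per-example (unbatched) inputs for simplicity
        return jnp.stack(qnode(inputs, weights))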

quantshah commented 2 years ago

Hi @dominicpasquali, it should work seamlessly, as the JAX interface makes a PennyLane QNode just another function that you can use anywhere in your code. The simple example below shows a loss function that takes the output of a neural network defined with Flax and passes it to a PennyLane quantum circuit. You can then differentiate through the loss function as you would for any other Flax/JAX code.

Thanks for opening the issue. @josh146 and I discussed it, and I would say there is no specific need to write a Flax layer for PennyLane. Unless you have a specific use case/example that does not work directly with the JAX interface, there shouldn't be any problem.

import pennylane as qml

import jax
from jax import numpy as jnp

from typing import Sequence
import flax.linen as nn

# A quantum circuit (QNode) defined with PennyLane
N = 2
dev = qml.device("default.qubit.jax", wires=N, shots=None)  # shots=None: exact expectation values, good ol' backprop

@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(weights: jnp.ndarray) -> float:
    """A random variational circuit ansatz.

    Args:
        weights (jnp.ndarray): The variational parameters for the ansatz (circuit).

    Returns:
        float: The expectation value.
    """
    qml.RandomLayers(weights, wires=range(N))
    return qml.expval(qml.PauliZ(0))

key = jax.random.PRNGKey(42)
weights = jax.random.uniform(key, shape=(N, 3))
print(circuit(weights))

# A neural network defined with Flax
class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x

model = MLP([2, 3])  # output shape (2, 3) matches the circuit's expected weight shape
batch = jnp.ones((2, 3))
variables = model.init(jax.random.PRNGKey(0), batch)
output = model.apply(variables, batch)

# PennyLane + Flax working together: autodiff straight through the circuit
def loss(variables):
    output = model.apply(variables, batch)
    out = circuit(output)
    return jnp.sum(out)

print(loss(variables))
print(jax.grad(loss)(variables))
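
And to close the loop, a single gradient-descent step on the Flax parameters looks like any other JAX update (a minimal sketch; the learning rate here is arbitrary):

# One plain gradient-descent step on the Flax parameters,
# differentiating straight through the quantum circuit.
lr = 0.1
grads = jax.grad(loss)(variables)
variables = jax.tree_util.tree_map(lambda p, g: p - lr * g, variables, grads)
print(loss(variables))  # loss after one update step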