PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

Propagate the gradient to part of an array only #5767

Open minhtriet opened 1 month ago

minhtriet commented 1 month ago

Feature details

I want to optimize the coordinates of the different molecules in a reaction A+B->AB. It would make sense to fix the coordinates of A and optimize only those of B. However, this doesn't work right now.

Suppose I have this JAX traced array

>>> coords
Traced<ConcreteArray([1. 1. 1.], dtype=float32)>with<JVPTrace(level=2/0)>
...

Then jnp.array([0, 0, 0, *coord]) won't work within a PennyLane context. The error is jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[2,3].
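For reference, the same fixed-plus-trainable construction is fine in plain JAX, so the failure seems specific to how PennyLane handles the array. A toy sketch (toy_loss is my stand-in quadratic, not the real qchem loss):

import jax
import jax.numpy as jnp

def toy_loss(coord):
    # atom A fixed at the origin, atom B's coordinates trainable
    full = jnp.concatenate([jnp.zeros(3), coord])
    return jnp.sum(full ** 2)

print(jax.grad(toy_loss)(jnp.ones(3)))  # [2. 2. 2.] -- gradient w.r.t. atom B only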

This is because in qchem/openfermion_obs.py we have the line geometry_dhf = qml.numpy.array(coordinates.reshape(len(symbols), 3)), which, at the end of the stack trace, converts the JAX array to a NumPy array.
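The error is easy to reproduce in isolation: any NumPy conversion of a traced array triggers it (a minimal sketch, independent of PennyLane):

import jax
import jax.numpy as jnp
import numpy as np

def f(x):
    return np.array(x).sum()  # np.array calls __array__() on the tracer

jax.grad(f)(jnp.ones(3))  # raises jax.errors.TracerArrayConversionError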

Implementation

Here is my minimal example to recreate the error; it fails at the qml.qchem.molecular_hamiltonian call inside loss_f (marked with a comment below).

import pennylane as qml
import jax
import optax

dev = qml.device("default.qubit", wires=4)

@qml.qnode(dev)
def circuit_expected(H):
    qml.BasisState([1, 1, 0, 0], wires=[0, 1, 2, 3])
    qml.DoubleExcitation(0.2, wires=[0, 1, 2, 3])
    return qml.expval(H)

def loss_f(coord):
    symbols = ["H", "H"]
    # fails here: molecular_hamiltonian converts the traced JAX array to NumPy
    H, qb = qml.qchem.molecular_hamiltonian(symbols, jax.numpy.array([0, 0, 0, *coord]))
    return circuit_expected(H)

coords = jax.numpy.array([1., 1., 1.])  # trainable coordinates of atom B
opt = optax.sgd(learning_rate=0.4)
opt_coords_state = opt.init(coords)

for i in range(10):
    grad_coordinates = jax.grad(loss_f, 0)(coords)
    updates, opt_coords_state = opt.update(grad_coordinates, opt_coords_state)
    coords = optax.apply_updates(coords, updates)
    print(grad_coordinates)
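Once the conversion issue is fixed, restricting the update to part of the array is straightforward on the optax side, e.g. by masking the gradient elementwise. A sketch with a stand-in loss (toy_loss is not the real qchem loss):

import jax
import jax.numpy as jnp
import optax

mask = jnp.array([0., 0., 0., 1., 1., 1.])  # zeros freeze atom A, ones train atom B

def toy_loss(coords):
    return jnp.sum((coords - 2.0) ** 2)  # stand-in for the qchem loss

coords = jnp.array([0., 0., 0., 1., 1., 1.])
opt = optax.sgd(learning_rate=0.4)
state = opt.init(coords)

for _ in range(10):
    grads = jax.grad(toy_loss)(coords) * mask  # zero the gradients of frozen entries
    updates, state = opt.update(grads, state)
    coords = optax.apply_updates(coords, updates)

print(coords)  # first three entries unchanged, last three moved toward 2.0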

How important would you say this feature is?

2: Somewhat important. Needed this quarter.

Additional information

No response

CatalinaAlbornoz commented 1 month ago

Hi @minhtriet,

As I answered in the forum thread, it looks like molecular_hamiltonian unfortunately only works with autograd, so you would need to go back to using PennyLane NumPy instead of JAX :crying_cat_face:.

We’ll look into adding a warning in the documentation about this and hopefully allowing it to work with JAX in future releases.
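If it helps in the meantime, here is a rough sketch of how the fixed/trainable split might look with PennyLane NumPy, where trainability is marked per array (circuit_expected is the QNode from your reproduction; I haven't run this end to end, so please treat it as a direction rather than a verified workflow):

import pennylane as qml
from pennylane import numpy as pnp

symbols = ["H", "H"]
coords_A = pnp.array([0.0, 0.0, 0.0], requires_grad=False)  # fixed atom A
coords_B = pnp.array([1.0, 1.0, 1.0], requires_grad=True)   # trainable atom B

def loss_f(coords_B):
    coords = pnp.concatenate([coords_A, coords_B])
    H, _ = qml.qchem.molecular_hamiltonian(symbols, coords)
    return circuit_expected(H)

opt = qml.GradientDescentOptimizer(stepsize=0.4)
for _ in range(10):
    coords_B = opt.step(loss_f, coords_B)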

soranjh commented 3 weeks ago

Thanks @minhtriet for opening the issue. As suggested by Catalina here, could you please try using the diff_hamiltonian function, which works with different frameworks? You might also look at this demo for more insight into differentiable qchem workflows.
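Roughly, the differentiable pattern looks like the sketch below, following the differentiable Hartree-Fock documentation: the Molecule is built from an initial geometry, and the trainable copy is passed in as args (circuit_expected is the QNode from your reproduction; treat this as a sketch rather than a verified workflow):

import pennylane as qml
from pennylane import numpy as pnp

symbols = ["H", "H"]
geometry = pnp.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]], requires_grad=True)
mol = qml.qchem.Molecule(symbols, geometry)

def loss_f(*args):
    H = qml.qchem.diff_hamiltonian(mol)(*args)  # Hamiltonian differentiable w.r.t. args
    return circuit_expected(H)

grad = qml.grad(loss_f)(geometry)  # gradient of the energy w.r.t. the nuclear coordinates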