valence-labs / mess

MESS: Modern Electronic Structure Simulations
https://valence-labs.github.io/mess/
MIT License

Autograd w.r.t. nuclei positions #22

Status: Open. Binbose opened this issue 4 weeks ago.

Binbose commented 4 weeks ago

Hey, I was wondering whether MESS can calculate forces on atoms with JAX's autodiff. I tried it like this:

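```python
import jax
import jax.numpy as jnp

# Import paths below are assumed from the mess package layout
from mess import Hamiltonian, basisset, minimise
from mess.structure import Structure
from mess.units import to_bohr


def struc_to_energy(r):
    """Total PBE/6-31G energy of water for atom positions r (Angstrom, shape (3, 3))."""
    mol = Structure(
        atomic_number=jnp.array([8, 1, 1]),  # O, H, H
        position=to_bohr(r),  # convert Angstrom -> Bohr
    )
    basis = basisset(mol, "6-31g")
    H = Hamiltonian(basis, xc_method="pbe")
    E, C, sol = minimise(H)  # only the energy E is used here
    return E


struc_to_energy_grad = jax.grad(struc_to_energy)
r = jnp.array([
    [0.0000, 0.0000, 0.1165],
    [0.0000, 0.7694, -0.4661],
    [0.0000, -0.7694, -0.4661],
])
forces = -struc_to_energy_grad(r)  # fails: Structure calls np.array on the traced r
```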

But I am getting errors from `np.array` being called on `r` inside the `Structure` constructor, and I don't see a good way to get around this.
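The failure mode can be reproduced without MESS. Forces are the negative gradient of the energy, F_I = -∂E/∂R_I, so `jax.grad` only works if every operation between `r` and `E` stays traceable. The sketch below uses a hypothetical toy pair-repulsion energy (`toy_energy`, a stand-in, not DFT) to show that the gradient works when positions stay as JAX arrays, and that calling `np.asarray` on the traced input (as `Structure` presumably does) raises `jax.errors.TracerArrayConversionError`. `broken_energy` is likewise a hypothetical helper that only mimics the reported behaviour.

```python
import jax
import jax.numpy as jnp
import numpy as np


def toy_energy(r):
    """Hypothetical toy energy sum_{i<j} 1/|r_i - r_j| (a stand-in, not DFT)."""
    i, j = jnp.triu_indices(r.shape[0], k=1)  # all unique atom pairs
    dist = jnp.linalg.norm(r[i] - r[j], axis=-1)
    return jnp.sum(1.0 / dist)


def broken_energy(r):
    """Mimics the reported failure: the traced input is converted to NumPy."""
    r_np = np.asarray(r)  # a tracer cannot become a concrete NumPy array
    return toy_energy(jnp.asarray(r_np))


r = jnp.array([
    [0.0000, 0.0000, 0.1165],
    [0.0000, 0.7694, -0.4661],
    [0.0000, -0.7694, -0.4661],
])

# Forces are the negative gradient of the energy w.r.t. positions
forces = -jax.grad(toy_energy)(r)
print(forces)

try:
    jax.grad(broken_energy)(r)
    conversion_failed = False
except jax.errors.TracerArrayConversionError:
    conversion_failed = True
print("np.asarray on a traced input fails:", conversion_failed)
```

Because the error is raised inside the function being differentiated, it cannot be avoided from user code by converting `r` beforehand; the conversion would presumably have to be removed inside the library itself.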

hatemhelal commented 3 weeks ago

We'd like to support this more generally, but a few more pieces need to be implemented to make your example work:

Both of these are feasible, but for the moment this is a work in progress. I suggest leaving this issue open as a reminder to check that this works as the different pieces come together.
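Until autodiff through `Structure` is supported, one possible stopgap (not something proposed in this thread) is central finite differences, which only needs the energy evaluated at concrete positions. The sketch below assumes an energy callable that accepts a NumPy array of positions, as `struc_to_energy` would when called outside `jax.grad`. `fd_forces`, `toy_energy` and `toy_forces_exact` are hypothetical helpers; the toy pair energy replaces MESS so the result can be checked against analytic forces.

```python
import numpy as np


def fd_forces(energy_fn, r, h=1e-4):
    """Central-difference forces F = -dE/dr for any callable energy_fn(r) -> scalar.

    Needs 2 * r.size energy evaluations, so it is only practical for small systems.
    """
    r = np.asarray(r, dtype=np.float64)
    forces = np.zeros_like(r)
    for idx in np.ndindex(r.shape):
        r_plus, r_minus = r.copy(), r.copy()
        r_plus[idx] += h
        r_minus[idx] -= h
        forces[idx] = -(float(energy_fn(r_plus)) - float(energy_fn(r_minus))) / (2 * h)
    return forces


def toy_energy(r):
    """Hypothetical stand-in for struc_to_energy: sum_{i<j} 1/|r_i - r_j|."""
    i, j = np.triu_indices(len(r), k=1)
    return np.sum(1.0 / np.linalg.norm(r[i] - r[j], axis=-1))


def toy_forces_exact(r):
    """Analytic forces of toy_energy: F_i = sum_j (r_i - r_j) / |r_i - r_j|^3."""
    diff = r[:, None, :] - r[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    np.fill_diagonal(dist, np.inf)  # drop the i == j self term
    return np.sum(diff / dist[..., None] ** 3, axis=1)


r = np.array([
    [0.0000, 0.0000, 0.1165],
    [0.0000, 0.7694, -0.4661],
    [0.0000, -0.7694, -0.4661],
])
approx = fd_forces(toy_energy, r)
exact = toy_forces_exact(r)
print("max abs error:", np.max(np.abs(approx - exact)))
```

With MESS one would call `fd_forces(struc_to_energy, r)` (untested here). Each call runs a full SCF, so water needs 2 × 9 = 18 SCF runs, the result is in energy units per Angstrom since `r` is given in Angstrom, and the SCF convergence tolerance limits how small `h` can usefully be.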