grimme-lab / dxtb

Efficient And Fully Differentiable Extended Tight-Binding
https://dxtb.readthedocs.io
Apache License 2.0

Logic behind new interface #144

Closed: hoelzerC closed this issue 6 months ago

hoelzerC commented 7 months ago

Arguably, the main functionality of dxtb will be to calculate:

  1. Energies
  2. Gradients
  3. (Higher derivatives)
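To make the relation between these three quantities concrete, here is a minimal sketch in plain PyTorch. It does not use dxtb: the function bond_energy and its parameters k and r0 are hypothetical, describing a toy harmonic bond between two atoms. It only shows how an energy, its gradient and a higher derivative (the Hessian) follow from one differentiable energy function via autograd.

```python
import torch


def bond_energy(positions: torch.Tensor, k: float = 0.5, r0: float = 1.0) -> torch.Tensor:
    # Hypothetical toy potential (not the xTB energy): E = 0.5 * k * (r - r0)^2
    r = torch.linalg.norm(positions[1] - positions[0])
    return 0.5 * k * (r - r0) ** 2


positions = torch.tensor(
    [[0.0, 0.0, 0.0], [0.0, 0.0, 1.5]], dtype=torch.double, requires_grad=True
)

# 1. Energy
energy = bond_energy(positions)

# 2. Gradient dE/dR with respect to all Cartesian coordinates, shape (natoms, 3)
(gradient,) = torch.autograd.grad(energy, positions)

# 3. Higher derivatives: Hessian d2E/dR2, returned with shape (2, 3, 2, 3)
hessian = torch.autograd.functional.hessian(bond_energy, positions.detach())
hessian = hessian.reshape(6, 6)  # (3N, 3N) layout
```

For this toy bond, the gradients on the two atoms are equal and opposite, and the Hessian is symmetric, as expected for a translationally invariant potential.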

To better grasp the new interface (the docs are not online yet), I am wondering about the design choices behind it. What is the intended way to calculate energies and forces?

calc = Calculator(numbers, parametrisation, opts)

results = calc.singlepoint(numbers, positions, charges)

gradient = -calc.forces_analytical(numbers, positions, charges).detach()

Is that how it should be done? Ideally, a new user would get this right intuitively. IMO, in the approach above, calc.forces_analytical should not require a factor of -1, since by definition F = -∇U. I am also wondering about the content of the results object: are the keys gradient and total_grad, for instance, still relevant? And which key unambiguously returns the (total) energy, as required by most users?

Suggestions

To me, this feels like a more natural approach. Happy to discuss.

marvinfriede commented 7 months ago

I had written some of this in the preliminary docs (here) and in an example (here).

In short, forces already carry the negative sign (F = -∇E). Forces from automatic differentiation (AD) are requested with forces, analytical forces with forces_analytical, and numerical forces with forces_numerical.
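The three routes to the forces, and the sign convention, can be illustrated with a toy sketch. This is not dxtb's code: the helpers forces_ad, forces_analytical and forces_numerical are hypothetical stand-ins that only mirror the idea of the three dxtb methods for a harmonic bond potential. All three return F = -∂E/∂R, so no extra factor of -1 is needed on the caller's side.

```python
import torch


def bond_energy(positions: torch.Tensor, k: float = 0.5, r0: float = 1.0) -> torch.Tensor:
    # Hypothetical toy potential (not the xTB energy): E = 0.5 * k * (r - r0)^2
    r = torch.linalg.norm(positions[1] - positions[0])
    return 0.5 * k * (r - r0) ** 2


def forces_ad(positions: torch.Tensor) -> torch.Tensor:
    # AD forces: differentiate the energy with autograd, then flip the sign
    pos = positions.detach().clone().requires_grad_(True)
    (grad,) = torch.autograd.grad(bond_energy(pos), pos)
    return -grad


def forces_analytical(positions: torch.Tensor, k: float = 0.5, r0: float = 1.0) -> torch.Tensor:
    # Hand-derived: F_1 = -k (r - r0) (R_1 - R_0) / r and F_0 = -F_1
    d = positions[1] - positions[0]
    r = torch.linalg.norm(d)
    f1 = -k * (r - r0) * d / r
    return torch.stack([-f1, f1])


def forces_numerical(positions: torch.Tensor, step: float = 1e-5) -> torch.Tensor:
    # Central finite differences: F_ij = -(E(R + h e_ij) - E(R - h e_ij)) / (2h)
    forces = torch.zeros_like(positions)
    for i in range(positions.shape[0]):
        for j in range(positions.shape[1]):
            plus = positions.clone()
            minus = positions.clone()
            plus[i, j] += step
            minus[i, j] -= step
            forces[i, j] = -(bond_energy(plus) - bond_energy(minus)) / (2 * step)
    return forces


positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.5]], dtype=torch.double)
f_ad = forces_ad(positions)
f_analytical = forces_analytical(positions)
f_numerical = forces_numerical(positions)
# The stretched bond pulls atom 1 back toward atom 0: its z-force is -0.25
```

AD and analytical forces agree to machine precision, while the finite-difference result agrees only up to the truncation error of the chosen step; that is why numerical forces are mainly useful as a check.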

marvinfriede commented 6 months ago

This now works with getters, as in ASE (documented in "Getting started"). With caching enabled, requesting the energy after the forces does not trigger an additional calculation, as shown by the counter of calculations run.

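import torch
import dxtb

dd = {"dtype": torch.double, "device": torch.device("cpu")}

# LiH: atomic numbers and Cartesian coordinates
numbers = torch.tensor([3, 1], device=dd["device"])
positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], **dd)
positions.requires_grad_(True)  # positions must require gradients for AD forces

calc = dxtb.calculators.GFN1Calculator(numbers, opts={"cache_enabled": True}, **dd)

# Requesting the forces runs the calculation; _ncalcs counts calculations run
forces = calc.get_forces(positions)
print(calc._ncalcs)
# The energy is served from the cache, so the counter does not increase
energy = calc.get_energy(positions)
print(calc._ncalcs)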
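The caching logic described above can be sketched with a hypothetical toy class. CachingCalculator and all of its internals are assumptions for illustration, not dxtb's implementation; the sketch only borrows the getter names and the _ncalcs counter from the snippet. One "calculation" computes energy and forces of a toy bond potential together and stores them, keyed on the last positions seen.

```python
import torch


class CachingCalculator:
    """Hypothetical toy calculator with ASE-style getters and a result cache.

    Not dxtb's implementation. For simplicity, all cached results are
    detached, so unlike a differentiable calculator this sketch cannot be
    backpropagated through.
    """

    def __init__(self, k: float = 0.5, r0: float = 1.0) -> None:
        self.k = k
        self.r0 = r0
        self._ncalcs = 0  # number of actual calculations run
        self._cache_positions = None  # positions of the last calculation
        self._cache = {}

    def _energy(self, positions: torch.Tensor) -> torch.Tensor:
        # Toy potential: E = 0.5 * k * (r - r0)^2
        r = torch.linalg.norm(positions[1] - positions[0])
        return 0.5 * self.k * (r - self.r0) ** 2

    def _calculate(self, positions: torch.Tensor) -> None:
        # Reuse the cached results if the positions have not changed
        if self._cache_positions is not None and torch.equal(
            positions.detach(), self._cache_positions
        ):
            return
        # One calculation yields both energy and forces (F = -dE/dR)
        pos = positions.detach().clone().requires_grad_(True)
        energy = self._energy(pos)
        (grad,) = torch.autograd.grad(energy, pos)
        self._cache = {"energy": energy.detach(), "forces": -grad}
        self._cache_positions = positions.detach().clone()
        self._ncalcs += 1

    def get_energy(self, positions: torch.Tensor) -> torch.Tensor:
        self._calculate(positions)
        return self._cache["energy"]

    def get_forces(self, positions: torch.Tensor) -> torch.Tensor:
        self._calculate(positions)
        return self._cache["forces"]


calc = CachingCalculator()
positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.5]], dtype=torch.double)

forces = calc.get_forces(positions)
print(calc._ncalcs)  # 1
energy = calc.get_energy(positions)
print(calc._ncalcs)  # 1, served from the cache
energy_moved = calc.get_energy(positions + 0.1)
print(calc._ncalcs)  # 2, new positions trigger a new calculation
```

Computing energy and forces in the same pass is what makes the cache effective: whichever getter is called first fills the cache for both.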