choderalab / modelforge

Infrastructure to implement and train NNPs
https://modelforge.readthedocs.io/en/latest/
MIT License

Strategy for long range interactions #226

Closed: wiederm closed this issue 2 weeks ago

wiederm commented 1 month ago

When calculating the energies of larger systems, e.g., condensed-phase systems, accurately capturing long-range interactions is critical for realistic and reliable results. To make our models more modular and to address these interactions effectively, we propose partitioning the total energy into three components:

  1. Short-range energy
  2. Long-range electrostatic energy
  3. Long-range (attractive) dispersion energy

Our current models already calculate the short-range energy (component 1) accurately. Incorporating the long-range electrostatic and dispersion energies (components 2 and 3) requires several additional strategies. One approach to computing long-range electrostatic interactions is to predict per-atom partial charges as a function of the 3D coordinates; this is the approach we focus on below.

1. Outputting Per-Atom Partial Charges

To compute long-range electrostatic interactions, we must first determine per-atom partial charges. This can be achieved by modifying our existing models to include multiple output heads, each generating a different property: one for the per-atom energy (as is currently done) and another for the per-atom partial charges. This requires adapting the models to support a user-specified number of output heads, ensuring flexibility across different simulation scenarios. PR #232 is addressing this issue.
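To illustrate the multi-head idea, here is a minimal sketch that assumes the model produces a shared per-atom feature matrix and that each head is a single linear layer. The class name `OutputHeads` and the `head_dims` argument are hypothetical and do not reflect the API introduced in PR #232.

```python
import numpy as np


class OutputHeads:
    """Map shared per-atom features to several named per-atom properties.

    Hypothetical sketch: each head is one linear layer with random weights;
    a real model would typically use small MLPs with learnable parameters.
    """

    def __init__(self, n_features, head_dims, seed=0):
        rng = np.random.default_rng(seed)
        # One (weight, bias) pair per named output head.
        self.heads = {
            name: (rng.normal(scale=0.1, size=(n_features, dim)), np.zeros(dim))
            for name, dim in head_dims.items()
        }

    def __call__(self, features):
        # features: (n_atoms, n_features) -> {head name: (n_atoms, dim)}
        return {name: features @ w + b for name, (w, b) in self.heads.items()}


# Two heads sharing the same representation: per-atom energy and partial charge.
heads = OutputHeads(n_features=16, head_dims={"per_atom_energy": 1, "per_atom_charge": 1})
features = np.random.default_rng(1).normal(size=(5, 16))  # 5 atoms, 16 features each
outputs = heads(features)
print({name: out.shape for name, out in outputs.items()})
```

Because the heads only differ in their final projection, adding a further per-atom property amounts to adding another entry to `head_dims`.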

2. Ensuring Charge Conservation

The sum of the per-atom partial charges must equal the total charge of the system. To enforce this conservation, we will implement a charge equilibration step that scales the per-atom partial charges so that they sum to the system's total charge, which is crucial for maintaining physical accuracy in simulations. PR #234 addresses this issue.
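A minimal sketch of charge equilibration for a batch of molecules. The text speaks of scaling; this sketch instead applies a uniform additive shift, one common choice (used, e.g., in PhysNet), because multiplicative rescaling breaks down when the raw charges sum to zero and would zero out every charge of a neutral molecule. The function and argument names (including `atomic_subsystem_indices`, the molecule index of each atom) are assumptions for illustration, not necessarily what PR #234 implements.

```python
import numpy as np


def equilibrate_charges(raw_charges, atomic_subsystem_indices, total_charges):
    """Shift per-atom charges so that each molecule sums to its total charge.

    raw_charges: (n_atoms,) predicted partial charges for all atoms in a batch
    atomic_subsystem_indices: (n_atoms,) index of the molecule each atom belongs to
    total_charges: (n_molecules,) target total charge of each molecule
    """
    n_molecules = len(total_charges)
    # Per-molecule sum of raw charges and number of atoms.
    charge_sum = np.bincount(atomic_subsystem_indices, weights=raw_charges, minlength=n_molecules)
    n_atoms = np.bincount(atomic_subsystem_indices, minlength=n_molecules)
    # Spread the charge excess (or deficit) evenly over the atoms of each molecule.
    correction = (total_charges - charge_sum) / n_atoms
    return raw_charges + correction[atomic_subsystem_indices]


raw = np.array([0.3, -0.1, 0.2, 0.5, 0.1])
idx = np.array([0, 0, 0, 1, 1])  # atoms 0-2 belong to molecule 0, atoms 3-4 to molecule 1
total = np.array([0.0, 1.0])     # a neutral molecule and a cation
charges = equilibrate_charges(raw, idx, total)
print(np.bincount(idx, weights=charges))  # per-molecule sums now match `total`
```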

3. Calculating Electrostatic Energies

Once the per-atom partial charges are obtained, we need to compute the electrostatic contribution to the total energy. Initially, we propose using a Coulomb potential with damping for short-range interactions. This approach can be extended with more sophisticated techniques, such as reaction fields. To do this efficiently, we need to keep neighborlists with different cutoffs in memory. PR #235 addresses this issue.
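A minimal sketch of a damped Coulomb energy evaluated over a precomputed pair list. It assumes an erf-based damping, q_i q_j erf(r_ij/σ)/r_ij, which stays finite as r_ij → 0 and recovers the bare Coulomb term for r_ij ≫ σ. The damping form, the value of σ, and the unit convention (nm, kJ/mol, elementary charges) are assumptions, not necessarily what PR #235 implements.

```python
import math

import numpy as np

# Coulomb constant in kJ/mol * nm / e^2 (assumed unit convention).
COULOMB_CONSTANT = 138.935456


def damped_coulomb_energy(charges, pair_indices, d_ij, sigma=0.1):
    """Sum erf-damped Coulomb interactions over a pair list.

    charges: (n_atoms,) partial charges in e
    pair_indices: (2, n_pairs) atom indices; each pair must appear only once
                  (if both (i, j) and (j, i) are present, halve the result)
    d_ij: (n_pairs,) pair distances in nm
    sigma: damping length in nm
    """
    i, j = pair_indices
    # erf(r / sigma) -> 0 at short range (no singularity), -> 1 at long range.
    damping = np.array([math.erf(r / sigma) for r in d_ij])
    return COULOMB_CONSTANT * np.sum(charges[i] * charges[j] * damping / d_ij)


# Opposite unit charges 1 nm apart: far beyond sigma, so essentially bare Coulomb.
energy = damped_coulomb_energy(np.array([1.0, -1.0]), np.array([[0], [1]]), np.array([1.0]))
print(energy)
```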

The energy expression should be configurable. Ideally, we want to encode the functional form of these energy terms in the TOML configuration file that governs the logic of the potential. This modular approach allows us and users to modify or extend the energy calculation without altering the core model.
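To illustrate the configuration-driven idea, here is a sketch in which a dictionary (standing in for a parsed TOML section) selects which energy terms are summed. The `energy_terms` key, the term names, and the registry are all hypothetical; modelforge's actual TOML schema may look different.

```python
def short_range(inputs):
    # Sum of the per-atom energies predicted by the network.
    return sum(inputs["per_atom_energy"])


def electrostatics(inputs):
    return inputs["electrostatic_energy"]


def dispersion(inputs):
    return inputs["dispersion_energy"]


# Registry mapping configuration names to energy functions (hypothetical names).
ENERGY_TERMS = {
    "short_range": short_range,
    "electrostatics": electrostatics,
    "dispersion": dispersion,
}


def total_energy(config, inputs):
    """Sum only the energy terms listed in the configuration."""
    return sum(ENERGY_TERMS[name](inputs) for name in config["energy_terms"])


# Hypothetical TOML equivalent: energy_terms = ["short_range", "electrostatics"]
config = {"energy_terms": ["short_range", "electrostatics"]}
inputs = {"per_atom_energy": [-1.0, -2.0], "electrostatic_energy": -0.5, "dispersion_energy": -0.1}
print(total_energy(config, inputs))  # -3.5
```

Adding a new functional form then means registering one more function, without touching the core model.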

4. Training the Additional Energy Components

Integrating these new energy components into the models requires additional considerations for the training phase:

Regularization of Partial Charges: We need to add a term to the loss function that penalizes deviations from the expected total charge. This regularization ensures that the partial charges remain physically meaningful and consistent with the total charge of the system.

Training on Dipole Moments: In addition to regularizing the charges, we will train the model to reproduce the dipole moment of the system. This approach ensures that the partial charges accurately reflect the molecular electrostatics.
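Both loss terms can be sketched for a batch as follows, with the dipole computed as μ = Σ_i q_i r_i. This is only a sketch: the loss weights and function names are hypothetical, and positions are used relative to the coordinate origin, which makes the dipole origin-independent only for neutral molecules (charged molecules usually need a reference point such as the center of mass, not handled here).

```python
import numpy as np


def molecular_dipoles(charges, positions, atomic_subsystem_indices, n_molecules):
    """Dipole moment of each molecule, sum_i q_i * r_i (positions relative to the origin)."""
    dipoles = np.zeros((n_molecules, 3))
    # Scatter-add each atom's q_i * r_i into the row of its molecule.
    np.add.at(dipoles, atomic_subsystem_indices, charges[:, None] * positions)
    return dipoles


def charge_and_dipole_loss(charges, positions, atomic_subsystem_indices,
                           total_charges, reference_dipoles,
                           charge_weight=1.0, dipole_weight=1.0):
    """Weighted squared-error penalties on total charge and dipole (hypothetical weights)."""
    n_molecules = len(total_charges)
    predicted_total = np.bincount(atomic_subsystem_indices, weights=charges, minlength=n_molecules)
    predicted_dipoles = molecular_dipoles(charges, positions, atomic_subsystem_indices, n_molecules)
    charge_loss = np.mean((predicted_total - total_charges) ** 2)
    dipole_loss = np.mean(np.sum((predicted_dipoles - reference_dipoles) ** 2, axis=1))
    return charge_weight * charge_loss + dipole_weight * dipole_loss


charges = np.array([1.0, -1.0])
positions = np.array([[0.0, 0.0, 1.0], [0.0, 0.0, 0.0]])
idx = np.array([0, 0])
dipoles = molecular_dipoles(charges, positions, idx, 1)
loss = charge_and_dipole_loss(charges, positions, idx, np.array([0.0]), np.array([[0.0, 0.0, 1.0]]))
print(dipoles, loss)
```

In practice these terms would be added to the energy (and force) loss rather than used on their own.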

For now, we will calculate dispersion corrections using the DFT-D3 implementation, which is available for both PyTorch and JAX. This is not a long-term solution, but I suggest that we first focus on electrostatics, using D3 for dispersion (since it is trivial to use), and then revisit dispersion.

As we start working on this, we will link the PRs to each of the outlined tasks. I will also document the general approach here.

chrisiacovella commented 1 month ago

Since the coordinates do not change during training, the dispersion correction does not change either, so it would likely be most efficient to precompute it (e.g., after we calculate the pair list). It should also be easy to make this modular, so that LJ 1/6 parameters can be used in place of D3. Longer term, we want to be able to train cross-interaction terms for an LJ approach, which would require a different implementation (since these terms could no longer be precomputed).
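A sketch of precomputing a coordinate-only correction once per conformer and reusing it in later epochs. For illustration, a simple pairwise -C6/r^6 term stands in for D3; the C6 value, the `conformer_id` cache key, and the class name are hypothetical. As noted above, this caching is only valid while the term has no trainable parameters.

```python
import numpy as np


def pairwise_c6_energy(d_ij, c6=1.0e-3):
    """Illustrative attractive dispersion, -C6 / r^6 summed over the given pairs."""
    return -np.sum(c6 / d_ij**6)


class DispersionCache:
    """Compute the dispersion energy once per conformer and reuse it afterwards."""

    def __init__(self, energy_fn=pairwise_c6_energy):
        self.energy_fn = energy_fn
        self._cache = {}
        self.n_evaluations = 0  # counts actual (non-cached) evaluations

    def __call__(self, conformer_id, d_ij):
        if conformer_id not in self._cache:
            self._cache[conformer_id] = self.energy_fn(d_ij)
            self.n_evaluations += 1
        return self._cache[conformer_id]


cache = DispersionCache()
d_ij = np.array([0.3, 0.4, 0.5])
for epoch in range(3):
    e_disp = cache(conformer_id=0, d_ij=d_ij)  # evaluated in epoch 0, cached afterwards
print(cache.n_evaluations)  # 1
```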

I've been looking into how to support multiple neighbor lists in the code. For clarity, I think each cutoff we define should have an explicit name, rather than being an entry in a list of cutoffs. So, in addition to the maximum_interaction_radius we currently have (for models with a single interaction), we can add something like maximum_charge_interaction_radius and maximum_dispersion_interaction_radius. The dispersion interaction radius would be used in the dispersion calculation above; we can enforce in the pydantic model that these need to be identical to match the original implementation.
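A sketch of named cutoffs, using a standard-library dataclass as a stand-in for the pydantic model. The three field names come from the comment above; the positivity check is an assumed validation rule, and the equality constraint mentioned above is not modeled here.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class CutoffConfig:
    # Cutoff for the existing single interaction (always required).
    maximum_interaction_radius: float
    # Optional additional cutoffs; None means the model does not use them.
    maximum_charge_interaction_radius: Optional[float] = None
    maximum_dispersion_interaction_radius: Optional[float] = None

    def __post_init__(self):
        # Every provided cutoff must be positive (assumed validation rule).
        for name, value in vars(self).items():
            if value is not None and value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")


config = CutoffConfig(0.5, maximum_charge_interaction_radius=1.2)
print(config)
```

Explicit fields make a typo in a configuration file fail loudly, whereas a bare list of cutoffs would silently accept values in the wrong order.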

In the code, we precompute all possible pairs during initialization, then use this information within the base model class to calculate the interacting pairs for a given cutoff (this returns an instance of PairListOutputs that also contains r_ij and d_ij). I think adapting the function that calculates the interacting pairs to take multiple cutoffs would be the most efficient option, since that is where we calculate the distances between all possible pairs before selecting those within range. It should therefore be straightforward to have it return a list containing one PairListOutputs instance per cutoff. Here I think it is fine to pass a list of cutoffs, with the masking operation simply looping over this list.
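The list-based variant described above can be sketched as follows: distances for all candidate pairs are computed once, then masked per cutoff. `PairList` is a hypothetical stand-in for PairListOutputs, and the sign convention r_ij = r_j - r_i is an assumption.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class PairList:
    """Hypothetical stand-in for PairListOutputs."""
    pair_indices: np.ndarray  # (2, n_pairs)
    r_ij: np.ndarray          # (n_pairs, 3) displacement vectors r_j - r_i
    d_ij: np.ndarray          # (n_pairs,) distances


def interacting_pairs(positions, all_pair_indices, cutoffs):
    """Compute distances for all candidate pairs once, then mask per cutoff."""
    i, j = all_pair_indices
    r_ij = positions[j] - positions[i]
    d_ij = np.linalg.norm(r_ij, axis=1)
    outputs = []
    for cutoff in cutoffs:  # cheap: only a comparison and a boolean selection per cutoff
        mask = d_ij < cutoff
        outputs.append(PairList(all_pair_indices[:, mask], r_ij[mask], d_ij[mask]))
    return outputs


positions = np.array([[0.0, 0.0, 0.0], [0.3, 0.0, 0.0], [1.0, 0.0, 0.0]])
all_pair_indices = np.array([[0, 0, 1], [1, 2, 2]])  # precomputed candidate pairs (i < j)
short_range, long_range = interacting_pairs(positions, all_pair_indices, cutoffs=[0.5, 1.2])
print(short_range.pair_indices, long_range.d_ij)
```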

How we handle the output of this operation would be defined in the _model_specific_input_preparation function of each potential. For a potential that requires multiple cutoffs, we'll need to add variables to the child of the NeuralNetworkData class that we define for each potential. For example, in addition to pair_indices we'd need something like pair_indices_for_electrostatics, in addition to d_ij something like d_ij_for_electrostatics, etc., all set by indexing into the list of PairListOutputs.
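A sketch of such a child data class and of how the preparation step could fill it. The base-class fields shown, the child class name, and the free function `model_specific_input_preparation` are simplified, hypothetical stand-ins; each per-cutoff pair list is represented as a plain dict.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass
class NeuralNetworkData:
    """Simplified stand-in for the base class (fields are illustrative)."""
    pair_indices: np.ndarray
    d_ij: np.ndarray
    r_ij: np.ndarray


@dataclass
class ElectrostaticsNeuralNetworkData(NeuralNetworkData):
    """Hypothetical child that also stores pair data for a second cutoff."""
    pair_indices_for_electrostatics: Optional[np.ndarray] = None
    d_ij_for_electrostatics: Optional[np.ndarray] = None
    r_ij_for_electrostatics: Optional[np.ndarray] = None


def model_specific_input_preparation(pairlists):
    """Fill the data class by indexing into the per-cutoff pair lists.

    pairlists: list of dicts with keys "pair_indices", "d_ij", "r_ij",
    ordered as the cutoffs were passed (short range first, then electrostatics).
    """
    short_range, electrostatics = pairlists
    return ElectrostaticsNeuralNetworkData(
        pair_indices=short_range["pair_indices"],
        d_ij=short_range["d_ij"],
        r_ij=short_range["r_ij"],
        pair_indices_for_electrostatics=electrostatics["pair_indices"],
        d_ij_for_electrostatics=electrostatics["d_ij"],
        r_ij_for_electrostatics=electrostatics["r_ij"],
    )


short = {"pair_indices": np.array([[0], [1]]), "d_ij": np.array([0.3]), "r_ij": np.array([[0.3, 0.0, 0.0]])}
elec = {"pair_indices": np.array([[0, 0, 1], [1, 2, 2]]), "d_ij": np.array([0.3, 1.0, 0.7]), "r_ij": np.zeros((3, 3))}
data = model_specific_input_preparation([short, elec])
print(data.d_ij.shape, data.d_ij_for_electrostatics.shape)
```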

wiederm commented 4 weeks ago

+1 on the D3 precalculation! This is a great idea; we can cache this in the first epoch or precalculate during neighbor list calculation!

wiederm commented 4 weeks ago

I agree that clear names for the cutoffs are better than a list of cutoffs. In the near future, we will have at most three cutoffs.

wiederm commented 4 weeks ago

I am a bit worried about the memory footprint of multiple neighborlists. For a training set like SPICE2 we could run into memory issues if we kept the current batch size, but that's something we will have to empirically test.

I think one initial design decision that might help in the future is to build the neighborlist using the largest cutoff. All other neighborlists then only need an index (or a mask) into this neighborlist, since it contains the pairs of every smaller-cutoff neighborlist.
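The superset idea can be sketched as follows: one neighbor list is built at the largest cutoff, and each smaller cutoff is stored only as a boolean mask over its pairs. For brevity, this hypothetical function handles a single molecule and enumerates all unique pairs directly; a batched implementation would restrict candidates to pairs within the same molecule.

```python
import numpy as np


def neighbor_list_with_masks(positions, cutoffs):
    """Build one neighbor list at the largest cutoff; smaller cutoffs become masks.

    Returns the pair indices (2, n_pairs), distances (n_pairs,), and a dict
    mapping each cutoff to a boolean mask over those pairs.
    """
    n_atoms = len(positions)
    i, j = np.triu_indices(n_atoms, k=1)  # all unique pairs i < j
    d_ij = np.linalg.norm(positions[j] - positions[i], axis=1)
    keep = d_ij < max(cutoffs)            # superset: pairs within the largest cutoff
    pair_indices = np.stack([i[keep], j[keep]])
    d_ij = d_ij[keep]
    masks = {cutoff: d_ij < cutoff for cutoff in cutoffs}
    return pair_indices, d_ij, masks


positions = np.array([[0.0, 0.0, 0.0], [0.3, 0.0, 0.0], [1.0, 0.0, 0.0]])
pair_indices, d_ij, masks = neighbor_list_with_masks(positions, cutoffs=(0.5, 1.2))
short_pairs = pair_indices[:, masks[0.5]]  # pairs for the smaller cutoff
print(pair_indices.shape, short_pairs)
```

The extra storage per additional cutoff is one boolean per pair in the largest list, rather than a full copy of indices and distances.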

chrisiacovella commented 4 weeks ago

I'm not sure we will have a memory footprint issue here (the memory issues we had before stemmed from enumerating all pairs of atoms in a batch, even those from separate molecules, before masking).

I was going to work on implementing this today and can set up some test cases to compare memory footprints. Appending masks for different cutoffs to the PairListOutputs data class could also work, but would require a bit of refactoring to add an intermediate step (which should not be hard).

I think the most important thing is going to be benchmarking this! I was going to create a dummy NNP (e.g., basically a copy of SchNet that allows multiple cutoffs to be defined, even if they don't really do anything in particular) to allow side-by-side comparisons.

wiederm commented 4 weeks ago

Agreed. Maybe we can start with the implementation you had in mind, then benchmark it and see whether we need to optimize!

chrisiacovella commented 3 weeks ago

> I am a bit worried about the memory footprint of multiple neighborlists. For a training set like SPICE2 we could run into memory issues if we kept the current batch size, but that's something we will have to empirically test.

I've been doing some testing as a function of cutoff. As you'd expect, the training time increases with the cutoff, but the difference in total memory allocation is negligible. Using SPICE2 with a batch size of 512 and SchNet, a cutoff of 5 Å allocates about 10 MB each time we call the function that calculates the interacting pairs, about 20 MB with 10 Å, and so on. So this should not be an issue.

chrisiacovella commented 3 weeks ago

PR #238 addresses the multiple cutoffs. As described in the PR, the neighbor list takes the cutoffs as a dict, so in principle it can handle any number of them (each key is intended to be the parameter name, e.g., "maximum_dispersion_interaction_radius"). The function returns a dict with the same keys to distinguish the output for each cutoff. This way, the distances between particle pairs only need to be calculated once, regardless of the number of cutoffs, and are then masked appropriately for each cutoff. Since the memory footprint was not an issue, this was the most efficient approach. The base model itself was modified to accept two additional cutoff parameters (one for dispersion, one for Coulomb), which, if provided, are automatically parsed into the dictionary passed to the neighbor list.
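A sketch of the dict-keyed design: only the cutoff parameters that were actually provided are collected into a dict keyed by parameter name, and the neighbor-list step returns its output under the same keys. The function names are hypothetical, and the charge/Coulomb parameter name is taken from the earlier comment, so it may differ from the one used in PR #238.

```python
import numpy as np

# Cutoff parameter names (the charge cutoff name is assumed; see the lead-in).
CUTOFF_PARAMETERS = (
    "maximum_interaction_radius",
    "maximum_charge_interaction_radius",
    "maximum_dispersion_interaction_radius",
)


def collect_cutoffs(**parameters):
    """Keep only the cutoff parameters that were provided (not None)."""
    return {
        name: parameters[name]
        for name in CUTOFF_PARAMETERS
        if parameters.get(name) is not None
    }


def neighbor_list(d_ij, cutoffs):
    """Mask precomputed pair distances per named cutoff; output keys match input keys."""
    return {name: np.nonzero(d_ij < cutoff)[0] for name, cutoff in cutoffs.items()}


cutoffs = collect_cutoffs(
    maximum_interaction_radius=0.5,
    maximum_dispersion_interaction_radius=1.0,
    maximum_charge_interaction_radius=None,  # not used by this model
)
d_ij = np.array([0.3, 0.7, 1.1])  # distances computed once for all candidate pairs
pairs = neighbor_list(d_ij, cutoffs)
print(cutoffs, pairs)
```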