openmm / openmm-ml

High level API for using machine learning models in OpenMM simulations

Add support for NequIP models #60

Open sef43 opened 9 months ago

sef43 commented 9 months ago

This PR adds support for NequIP models to openmm-ml. No pre-trained models are available, but the model framework is well defined, so users can run their own trained NequIP models in OpenMM simulations.

Also adds code to compute neighbor lists with PyTorch, which will be used for MACE models too. (The NNPOps neighbor list can be added later.)
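For readers unfamiliar with what such a neighbor list contains, here is a minimal pure-Python sketch. It is not the code in this PR (which uses PyTorch tensors); the function name `neighbor_pairs`, the brute-force O(N²) loop, and the orthorhombic minimum-image handling are all illustrative assumptions.

```python
import math
from itertools import combinations


def neighbor_pairs(positions, cutoff, box=None):
    """Return (i, j, distance) for every pair of atoms closer than `cutoff`.

    positions: list of (x, y, z) tuples.
    box: optional (Lx, Ly, Lz) of an orthorhombic periodic cell. The
         minimum-image step below assumes cutoff <= min(box) / 2.
    """
    pairs = []
    for (i, ri), (j, rj) in combinations(enumerate(positions), 2):
        delta = [b - a for a, b in zip(ri, rj)]
        if box is not None:
            # Wrap each component to the nearest periodic image.
            delta = [d - length * round(d / length) for d, length in zip(delta, box)]
        dist = math.sqrt(sum(d * d for d in delta))
        if dist < cutoff:
            pairs.append((i, j, dist))
    return pairs


positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (9.5, 0.0, 0.0)]
open_pairs = neighbor_pairs(positions, cutoff=2.0)
periodic_pairs = neighbor_pairs(positions, cutoff=2.0, box=(10.0, 10.0, 10.0))
print(len(open_pairs), len(periodic_pairs))
```

Without a box only atoms 0 and 1 are neighbors; with the 10-unit periodic box the third atom also falls within the cutoff of both others through its periodic image.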

Addresses #48; see https://github.com/mir-group/nequip/issues/288 for further discussion.

TODO: Need to add testing, but I'm not sure how to do this cleanly in CI, considering NequIP needs to be installed via pip.
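One possible way to handle the CI question, sketched here with the standard library only, is to skip the NequIP tests whenever the package is not importable. This is an assumption about how the tests could be organized, not what the PR does; `has_module` and `TestNequIPPotential` are hypothetical names.

```python
import importlib.util
import unittest


def has_module(name):
    """Return True if the top-level package `name` can be found in this environment."""
    return importlib.util.find_spec(name) is not None


HAS_NEQUIP = has_module("nequip")


@unittest.skipUnless(HAS_NEQUIP, "nequip is not installed (pip-only dependency)")
class TestNequIPPotential(unittest.TestCase):
    def test_package_is_importable(self):
        # Placeholder: a real test would build an MLPotential from a small
        # deployed model and compare its energy against a reference value.
        import nequip  # noqa: F401


print("nequip available:", HAS_NEQUIP)
```

With this pattern the test suite still passes in environments where only the conda packages are installed, and the NequIP tests run wherever a CI job adds the pip install step.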

jchodera commented 7 months ago

Can we train a NequIP model on SPICE and make it usable through openmm-ml?

svarner9 commented 2 months ago

Hello,

Has there been any further progress on this? I have used NequIP in LAMMPS but would like to instead use OpenMM because it is more compatible with the enhanced sampling packages that I use.

I have tried running simulations with a NequIP potential using openmm-ml in its current state; however, the speed is significantly slower than in LAMMPS. Both simulations run on a single GPU, but in LAMMPS I also use 32 CPU threads and Kokkos.

I am not sure if I am doing something incorrect in running openmm-ml, but currently it is unusable for my rather simple system of 645 atoms. Is it expected for it to be slow on a system of this size in its current state?

I can provide further information if needed. Thank you so much in advance!

Best, Sam

JMorado commented 2 months ago

@svarner9, could you try the current implementation available here? It uses the NNPOps neighbor list, so I anticipate it might be slightly faster for a system of the size you're working with. You can create the MLPotential using something along these lines:

potential = MLPotential('nequip', modelPath='model.pth', lengthScale=0.1, energyScale=4.184)
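The `lengthScale` and `energyScale` arguments look like conversion factors from the model's units to OpenMM's nanometers and kJ/mol (the script further down the thread labels them "Angstrom to nm" and "eV to kJ/mol"). Under that assumption, the factors can be derived as in the sketch below; the constant names and the helper `to_openmm_units` are illustrative, not part of openmm-ml.

```python
# Exact SI values from the 2019 redefinition.
ELEMENTARY_CHARGE_C = 1.602176634e-19  # also joules per electronvolt
AVOGADRO_PER_MOL = 6.02214076e23

ANGSTROM_TO_NM = 0.1
KCAL_TO_KJ = 4.184  # thermochemical calorie
EV_TO_KJ_PER_MOL = ELEMENTARY_CHARGE_C * AVOGADRO_PER_MOL / 1000.0


def to_openmm_units(energy, length, energy_scale, length_scale):
    """Scale a model energy and a length into kJ/mol and nm."""
    return energy * energy_scale, length * length_scale


print(f"eV -> kJ/mol: {EV_TO_KJ_PER_MOL:.5f}")
print(to_openmm_units(1.0, 1.0, KCAL_TO_KJ, ANGSTROM_TO_NM))
```

So `energyScale=4.184` would suit a model trained in kcal/mol, while a model trained in eV needs a factor of roughly 96.485.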

What speed-up did you observe in your LAMMPS simulations compared to OpenMM/OpenMM-ML?

JMorado commented 2 months ago

Just for the record, I'm posting here a comparison between the energies I get when using this OpenMM-ML interface and the ASE-like NequIPCalculator. The script I'm using to calculate the energies is the following:

import openmm as mm
import openmm.app as app
import openmm.unit as unit
from ase import Atoms
from nequip.ase.nequip_calculator import NequIPCalculator
from openmmml import MLPotential

lengthScale = 0.1  # Angstrom to nm
energyScale = 96.4853075  # eV to kJ/mol
model = "si-deployed.pth"
pdb_file = "si.pdb"

# Calculate the energy using a NequIPCalculator
pdb = app.PDBFile(pdb_file)
calculator = NequIPCalculator.from_deployed_model(model)
atoms_string = "".join([atom.element.symbol for atom in pdb.topology.atoms()])
positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.angstrom)
cell = (
    pdb.topology.getPeriodicBoxVectors().value_in_unit(unit.angstrom)
    if pdb.topology.getPeriodicBoxVectors()
    else None
)
atoms = Atoms(
    atoms_string, positions=positions, cell=cell, pbc=cell is not None
)
calculator.calculate(atoms)
pot_energy = calculator.get_potential_energy()
print("NequIPCalculator energy: {}".format(pot_energy * energyScale))

# Calculate the energy using OpenMM
potential = MLPotential(
    "nequip",
    modelPath=model,
    lengthScale=lengthScale,
    energyScale=energyScale,
)

system = potential.createSystem(pdb.topology)
integrator = mm.LangevinIntegrator(
    300 * unit.kelvin, 1.0 / unit.picoseconds, 1.0 * unit.femtosecond
)
simulation = app.Simulation(pdb.topology, system, integrator)
simulation.context.setPositions(pdb.positions)

print(
    "OpenMM-ML energy: {}".format(
        simulation.context.getState(getEnergy=True)
        .getPotentialEnergy()
        .value_in_unit(unit.kilojoules_per_mole)
    )
)

Toluene (NequIP, No PBC)

NequIPCalculator energy: -710491.18525 
OpenMM-ML energy: -710491.1875

Si (Allegro, PBC)

NequIPCalculator energy: -802582.8787801563
OpenMM-ML energy: -802582.875

Values are in kJ/mol. They seem to disagree from the 3rd decimal place onwards. I have checked, and the same input data is being passed to the model.

input_data.zip

peastman commented 2 months ago

They agree to eight significant digits, which is the accuracy of single precision. Do the forces have similar agreement? If so, I think it's fine.
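The single-precision explanation can be checked directly: rounding the two NequIPCalculator energies quoted above to the nearest float32 reproduces the OpenMM-ML values. A minimal sketch using only the standard library (the helper name `to_float32` is made up for this example):

```python
import struct


def to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


toluene_ref = -710491.18525       # kJ/mol, NequIPCalculator (toluene)
si_ref = -802582.8787801563       # kJ/mol, NequIPCalculator (Si)

print(to_float32(toluene_ref))    # -710491.1875
print(to_float32(si_ref))         # -802582.875
```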

Is there an option to predict a formation energy instead of total energy, or to subtract off per-atom mean energies? That leads to a much smaller output value and better accuracy.
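To illustrate why a smaller output value helps, the sketch below computes the gap between adjacent float32 values at two magnitudes. The total energy is the toluene value from this thread; the shifted energy of -500 kJ/mol is a purely hypothetical number standing in for what might remain after subtracting per-atom reference energies.

```python
import math


def float32_spacing(x):
    """Gap between adjacent float32 values near x (normal range, x != 0)."""
    _, exponent = math.frexp(abs(x))  # abs(x) = m * 2**exponent with 0.5 <= m < 1
    return 2.0 ** (exponent - 24)     # float32 has 24 significand bits


total_energy = -710491.1875   # kJ/mol, toluene total energy from the thread
shifted_energy = -500.0       # kJ/mol, hypothetical value after subtracting references

print(float32_spacing(total_energy))    # 0.0625
print(float32_spacing(shifted_energy))  # about 3.05e-05
```

At the total-energy magnitude the energy can only be resolved to 0.0625 kJ/mol, whereas a value of a few hundred kJ/mol is resolved about two thousand times more finely.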

JMorado commented 2 months ago

That's true, thanks for pointing that out. Regarding the forces, this is what I get (values in kJ/mol/nm):

Toluene
```
NequIPCalculator forces:
[[ 318.61404 -1153.1539 783.135 ]
 [ 379.53235 455.95953 -261.0112 ]
 [ 1114.3433 1182.2357 163.06062 ]
 [ -266.91818 1380.0348 146.71664 ]
 [ -857.644 244.97173 -12.799915 ]
 [ -966.69995 -1690.4469 -168.81912 ]
 [ 257.23373 -497.74158 -25.056156 ]
 [ 224.72466 211.90524 -195.86229 ]
 [ 495.38205 300.3262 -238.29967 ]
 [ 71.34846 423.93893 -45.857525 ]
 [ -183.62227 130.88791 -7.647993 ]
 [ -107.9876 -119.39723 -18.707212 ]
 [ -366.71957 -163.53647 -29.577057 ]
 [ -106.718575 -82.88132 -16.851692 ]
 [ -4.8681235 -623.1028 -72.42232 ]]
OpenMM-ML forces:
[[ 318.61157227 -1153.1529541 783.13665771]
 [ 379.53060913 455.96350098 -261.01147461]
 [ 1114.33947754 1182.234375 163.06036377]
 [ -266.9078064 1380.03112793 146.71669006]
 [ -857.64770508 244.96878052 -12.80025387]
 [ -966.69647217 -1690.44335938 -168.81869507]
 [ 257.23568726 -497.74768066 -25.05669785]
 [ 224.72613525 211.90498352 -195.86283875]
 [ 495.38165283 300.32611084 -238.29972839]
 [ 71.35070801 423.93704224 -45.85838318]
 [ -183.62208557 130.88764954 -7.64784527]
 [ -107.99130249 -119.38985443 -18.70673943]
 [ -366.71969604 -163.53736877 -29.57713509]
 [ -106.72071075 -82.88322449 -16.85195541]
 [ -4.87014723 -623.09869385 -72.42201233]]
Difference:
[[ 2.47192383e-03 -9.76562500e-04 -1.64794922e-03]
 [ 1.73950195e-03 -3.96728516e-03 2.74658203e-04]
 [ 3.78417969e-03 1.34277344e-03 2.59399414e-04]
 [-1.03759766e-02 3.66210938e-03 -4.57763672e-05]
 [ 3.72314453e-03 2.94494629e-03 3.38554382e-04]
 [-3.47900391e-03 -3.54003906e-03 -4.27246094e-04]
 [-1.95312500e-03 6.10351562e-03 5.41687012e-04]
 [-1.48010254e-03 2.59399414e-04 5.49316406e-04]
 [ 3.96728516e-04 9.15527344e-05 6.10351562e-05]
 [-2.25067139e-03 1.89208984e-03 8.58306885e-04]
 [-1.83105469e-04 2.59399414e-04 -1.47819519e-04]
 [ 3.70025635e-03 -7.37762451e-03 -4.73022461e-04]
 [ 1.22070312e-04 9.00268555e-04 7.82012939e-05]
 [ 2.13623047e-03 1.90734863e-03 2.63214111e-04]
 [ 2.02369690e-03 -4.08935547e-03 -3.05175781e-04]]
```
Si ``` NequIPCalculator forces: [[ 67.79043 -28.809738 11.53376 ] [ -4.886628 -25.532581 -15.155592 ] [ 1.4763676 -14.482292 19.402218 ] [ 17.322676 -69.07709 -25.55506 ] [ 19.82048 33.217815 -20.82479 ] [-45.851116 17.135185 -9.709675 ] [ 4.5035777 -24.29101 -11.166489 ] [ 1.9369186 -1.6387768 -1.4257112] [ 13.536634 -22.138472 31.538412 ] [-14.554014 11.717597 -19.121832 ] [-15.295826 35.69589 -3.7766256] [ 27.920511 -61.219616 -35.173405 ] [ 31.38432 44.064106 -10.46437 ] [-15.414515 15.491039 6.3312597] [-10.714798 -2.390285 7.777393 ] [ 9.10093 21.255102 -8.837459 ] [-12.911689 -5.2226615 65.043106 ] [ 35.767906 11.211081 -12.875771 ] [ 60.736385 -18.289862 3.730354 ] [ -8.458305 3.356562 -4.8178754] [-13.878986 18.876963 17.74003 ] [-24.694405 -32.99745 20.24441 ] [-34.090267 -15.701595 10.985336 ] [-16.814713 -11.162299 42.942413 ] [ -3.4840176 -5.062717 -13.725371 ] [-33.165398 5.9761963 -18.1375 ] [ 48.521038 -13.241893 17.688929 ] [-10.354681 -2.148305 -25.099829 ] [-37.356796 44.274803 -34.508373 ] [ 44.269165 -5.420011 -13.364778 ] [ 2.4421818 96.35251 26.062864 ] [ -9.241314 -2.5163426 -11.17031 ] [101.58642 9.112973 6.6917353] [ 45.18976 -18.290195 -23.636248 ] [-68.591415 -39.544487 61.845932 ] [-37.37915 -17.486694 51.09612 ] [-19.860252 0.7089257 -34.855164 ] [-34.38795 43.093174 -24.385368 ] [-96.76243 33.55584 -40.637005 ] [ 25.740808 -18.035166 -53.522533 ] [-17.347645 -14.707738 20.247072 ] [-27.524693 -50.164726 47.998127 ] [ 78.78394 5.531705 24.59482 ] [ 42.468536 46.616627 -52.593685 ] [ -6.6523476 16.796291 -87.7328 ] [ 3.4838438 -38.520428 8.806853 ] [ 2.1332002 37.900658 39.479454 ] [ 36.56353 -38.588394 19.898565 ] [ 27.694416 -80.263596 15.579612 ] [ 35.44082 12.952968 -50.18059 ] [-49.62549 26.632977 29.234938 ] [ 49.63715 -61.17182 70.60961 ] [ 7.865634 -16.822647 -13.332666 ] [-86.48338 88.669395 -49.876156 ] [ 47.17594 18.837576 -2.4321811] [ -6.342099 15.388432 21.146124 ] [ 38.588196 27.882034 34.625492 ] [-20.583471 14.237654 
-1.0932204] [-23.4871 72.94298 -3.7524729] [-11.276121 39.70276 -30.83238 ] [-14.0973425 -10.143854 -37.7538 ] [ 7.0738263 -61.706738 -17.025831 ] [-60.275654 -20.867666 60.231632 ] [-44.111538 -21.53065 25.44634 ]] OpenMM-ML forces: [[ 67.79095459 -28.80946541 11.53389549] [ -4.88648033 -25.5328598 -15.15541744] [ 1.47649062 -14.48234749 19.40254593] [ 17.32279015 -69.07706451 -25.55541039] [ 19.82018471 33.21785736 -20.82455444] [-45.85109329 17.13524628 -9.70964622] [ 4.50365496 -24.29071236 -11.16631031] [ 1.9371053 -1.63878345 -1.4259665 ] [ 13.53680134 -22.13811111 31.53835106] [-14.55395412 11.71772671 -19.12192345] [-15.29564571 35.69577789 -3.77657413] [ 27.92069817 -61.21953964 -35.17324066] [ 31.38408089 44.06418228 -10.46452713] [-15.41459084 15.49117565 6.33090544] [-10.71499538 -2.39058208 7.77688837] [ 9.10089684 21.25516891 -8.83750343] [-12.91157341 -5.22240162 65.04360199] [ 35.76815414 11.2110014 -12.87581348] [ 60.73667526 -18.28949165 3.73050475] [ -8.45816994 3.35656095 -4.81774998] [-13.87901878 18.87675095 17.73980713] [-24.69457817 -32.99786758 20.24431419] [-34.09024811 -15.70182705 10.98557568] [-16.81512642 -11.16264534 42.94252014] [ -3.48382878 -5.06294775 -13.72527027] [-33.1651001 5.97644854 -18.13700104] [ 48.52108383 -13.24161148 17.68883514] [-10.35446072 -2.14804649 -25.0998497 ] [-37.35691833 44.27435303 -34.50836182] [ 44.26918411 -5.42039633 -13.36488342] [ 2.44198561 96.35207367 26.06266022] [ -9.24124241 -2.51659489 -11.17035961] [101.58667755 9.11304665 6.69179535] [ 45.1896286 -18.29023361 -23.63647842] [-68.59138489 -39.54423523 61.84625626] [-37.37944412 -17.48635101 51.09622955] [-19.86007118 0.70900166 -34.85506821] [-34.38783264 43.09328079 -24.38536835] [-96.76257324 33.55570602 -40.63653564] [ 25.74025154 -18.03463173 -53.52237701] [-17.34791756 -14.70756245 20.24714851] [-27.52493286 -50.16454697 47.99803543] [ 78.78383636 5.53193188 24.59463501] [ 42.46936798 46.61732101 -52.59399033] [ -6.65289164 16.7964077 
-87.73320007] [ 3.48389125 -38.52040863 8.80679893] [ 2.1330018 37.9006958 39.47941971] [ 36.56356049 -38.58841705 19.89873314] [ 27.69441986 -80.26367188 15.57967281] [ 35.44096756 12.95310497 -50.180439 ] [-49.62584305 26.63298416 29.23537064] [ 49.63778305 -61.17185593 70.609375 ] [ 7.86521959 -16.82258034 -13.33259678] [-86.48326874 88.66952515 -49.87618256] [ 47.17575455 18.83693504 -2.43251538] [ -6.34226656 15.38792515 21.14606094] [ 38.58839035 27.88191795 34.62520218] [-20.58321762 14.23798943 -1.09339654] [-23.48657036 72.9430542 -3.75274563] [-11.27591801 39.70261765 -30.83240891] [-14.0974226 -10.14392471 -37.75414658] [ 7.07348776 -61.70677948 -17.02573586] [-60.27626038 -20.86861038 60.23205185] [-44.11212921 -21.53063011 25.44634056]] Difference: [[-5.26428223e-04 -2.72750854e-04 -1.35421753e-04] [-1.47819519e-04 2.78472900e-04 -1.74522400e-04] [-1.23023987e-04 5.53131104e-05 -3.28063965e-04] [-1.14440918e-04 -2.28881836e-05 3.50952148e-04] [ 2.95639038e-04 -4.19616699e-05 -2.34603882e-04] [-2.28881836e-05 -6.10351562e-05 -2.86102295e-05] [-7.72476196e-05 -2.97546387e-04 -1.78337097e-04] [-1.86681747e-04 6.67572021e-06 2.55346298e-04] [-1.66893005e-04 -3.60488892e-04 6.10351562e-05] [-6.00814819e-05 -1.29699707e-04 9.15527344e-05] [-1.80244446e-04 1.10626221e-04 -5.14984131e-05] [-1.86920166e-04 -7.62939453e-05 -1.64031982e-04] [ 2.38418579e-04 -7.62939453e-05 1.57356262e-04] [ 7.62939453e-05 -1.36375427e-04 3.54290009e-04] [ 1.97410583e-04 2.97069550e-04 5.04493713e-04] [ 3.33786011e-05 -6.67572021e-05 4.48226929e-05] [-1.15394592e-04 -2.59876251e-04 -4.95910645e-04] [-2.47955322e-04 7.91549683e-05 4.29153442e-05] [-2.89916992e-04 -3.70025635e-04 -1.50680542e-04] [-1.35421753e-04 9.53674316e-07 -1.25408173e-04] [ 3.24249268e-05 2.11715698e-04 2.23159790e-04] [ 1.73568726e-04 4.15802002e-04 9.53674316e-05] [-1.90734863e-05 2.31742859e-04 -2.39372253e-04] [ 4.13894653e-04 3.46183777e-04 -1.06811523e-04] [-1.88827515e-04 2.30789185e-04 -1.01089478e-04] 
[-2.97546387e-04 -2.52246857e-04 -4.99725342e-04] [-4.57763672e-05 -2.81333923e-04 9.34600830e-05] [-2.20298767e-04 -2.58445740e-04 2.09808350e-05] [ 1.22070312e-04 4.50134277e-04 -1.14440918e-05] [-1.90734863e-05 3.85284424e-04 1.05857849e-04] [ 1.96218491e-04 4.34875488e-04 2.04086304e-04] [-7.15255737e-05 2.52246857e-04 4.95910645e-05] [-2.59399414e-04 -7.34329224e-05 -6.00814819e-05] [ 1.29699707e-04 3.81469727e-05 2.30789185e-04] [-3.05175781e-05 -2.51770020e-04 -3.24249268e-04] [ 2.93731689e-04 -3.43322754e-04 -1.10626221e-04] [-1.81198120e-04 -7.59363174e-05 -9.53674316e-05] [-1.18255615e-04 -1.06811523e-04 0.00000000e+00] [ 1.44958496e-04 1.33514404e-04 -4.69207764e-04] [ 5.56945801e-04 -5.34057617e-04 -1.56402588e-04] [ 2.72750854e-04 -1.75476074e-04 -7.62939453e-05] [ 2.40325928e-04 -1.79290771e-04 9.15527344e-05] [ 1.06811523e-04 -2.26974487e-04 1.85012817e-04] [-8.31604004e-04 -6.94274902e-04 3.05175781e-04] [ 5.44071198e-04 -1.16348267e-04 3.96728516e-04] [-4.74452972e-05 -1.90734863e-05 5.43594360e-05] [ 1.98364258e-04 -3.81469727e-05 3.43322754e-05] [-3.05175781e-05 2.28881836e-05 -1.67846680e-04] [-3.81469727e-06 7.62939453e-05 -6.10351562e-05] [-1.48773193e-04 -1.37329102e-04 -1.52587891e-04] [ 3.54766846e-04 -7.62939453e-06 -4.32968140e-04] [-6.33239746e-04 3.43322754e-05 2.36511230e-04] [ 4.14371490e-04 -6.67572021e-05 -6.96182251e-05] [-1.14440918e-04 -1.29699707e-04 2.67028809e-05] [ 1.86920166e-04 6.40869141e-04 3.34262848e-04] [ 1.67369843e-04 5.06401062e-04 6.29425049e-05] [-1.94549561e-04 1.16348267e-04 2.89916992e-04] [-2.53677368e-04 -3.35693359e-04 1.76191330e-04] [-5.30242920e-04 -7.62939453e-05 2.72750854e-04] [-2.03132629e-04 1.41143799e-04 2.86102295e-05] [ 8.01086426e-05 7.05718994e-05 3.47137451e-04] [ 3.38554382e-04 4.19616699e-05 -9.53674316e-05] [ 6.06536865e-04 9.44137573e-04 -4.19616699e-04] [ 5.91278076e-04 -1.90734863e-05 0.00000000e+00]] ```
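A compact way to summarize tables like these is the largest absolute and relative deviation. The sketch below does this in plain Python for the first three toluene atoms only (values copied from the listing above); the helper name `max_abs_and_rel_diff` is illustrative.

```python
def max_abs_and_rel_diff(reference, candidate):
    """Return the largest absolute and largest relative component-wise difference."""
    max_abs, max_rel = 0.0, 0.0
    for ref_row, cand_row in zip(reference, candidate):
        for ref, cand in zip(ref_row, cand_row):
            diff = abs(ref - cand)
            max_abs = max(max_abs, diff)
            if ref != 0.0:
                max_rel = max(max_rel, diff / abs(ref))
    return max_abs, max_rel


# First three toluene atoms, forces in kJ/mol/nm.
nequip_forces = [
    [318.61404, -1153.1539, 783.135],
    [379.53235, 455.95953, -261.0112],
    [1114.3433, 1182.2357, 163.06062],
]
openmm_forces = [
    [318.61157227, -1153.1529541, 783.13665771],
    [379.53060913, 455.96350098, -261.01147461],
    [1114.33947754, 1182.234375, 163.06036377],
]

max_abs, max_rel = max_abs_and_rel_diff(nequip_forces, openmm_forces)
print(f"max abs diff: {max_abs:.2e} kJ/mol/nm, max rel diff: {max_rel:.2e}")
```

For this subset the relative deviations stay below about 1e-5, in line with the energy comparison.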

As far as I know, there's no option to get the interaction energy or to subtract the per-atom mean energies. The atomic energies sum up to the total energy.

JMorado commented 2 months ago

This is done from my side. If someone could take a look and review the changes, that would be great. Performance benchmarks on test models can be found here.

Many thanks!

svarner9 commented 1 month ago

@JMorado I went ahead and tested the version on the nequip branch, but I am unable to get it to run on a GPU. When I specify the potential and the platform in the following way,

potential = MLPotential(
    "nequip",
    modelPath="model.pth",
    lengthScale=0.1,
    energyScale=96.48,
)
...

plat = openmm.Platform.getPlatformByName("CUDA")
properties = {"Precision": "double", "DeviceIndex": "0",
              "UseBlockingSync": "false"}
simulation = app.Simulation(topology, system, integrator, plat, properties)

I get the following set of warnings and errors:

/home/svarner/miniconda3/envs/practicum/lib/python3.11/site-packages/torchani/aev.py:16: UserWarning: cuaev not installed
  warnings.warn("cuaev not installed")
/home/svarner/miniconda3/envs/practicum/lib/python3.11/site-packages/nequip/scripts/deploy.py:138: UserWarning: Models deployed before v0.6.0 don't contain information about their default_dtype or model_dtype; assuming the old default of float32 for both, but this might not be right if you had explicitly set default_dtype=float64.
  warnings.warn(
/home/svarner/miniconda3/envs/practicum/lib/python3.11/site-packages/nequip/utils/_global_options.py:59: UserWarning: !! Upstream issues in PyTorch versions >1.11 have been seen to cause unusual performance degredations on some CUDA systems that become worse over time; see https://github.com/mir-group/nequip/discussions/311. At present we *strongly* recommend the use of PyTorch 1.11 if using CUDA devices; while using other versions if you observe this problem, an unexpected lack of this problem, or other strange behavior, please post in the linked GitHub issue.
  warnings.warn(
/home/svarner/miniconda3/envs/practicum/lib/python3.11/site-packages/nequip/utils/_global_options.py:70: UserWarning: Setting the GLOBAL value for jit fusion strategy to `[('DYNAMIC', 3)]` which is different than the previous value of `[('STATIC', 2), ('DYNAMIC', 10)]`
  warnings.warn(
Traceback (most recent call last):
  File "/home/svarner/Practicum/sim.py", line 174, in <module>
    run(1,1,1,1,1)
  File "/home/svarner/Practicum/sim.py", line 145, in run
    simulation = app.Simulation(topology, system, integrator, plat, properties)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/svarner/miniconda3/envs/practicum/lib/python3.11/site-packages/openmm/app/simulation.py", line 106, in __init__
    self.context = mm.Context(self.system, self.integrator, platform, platformProperties)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/svarner/miniconda3/envs/practicum/lib/python3.11/site-packages/openmm/openmm.py", line 12171, in __init__
    _openmm.Context_swiginit(self, _openmm.new_Context(*args))
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
openmm.OpenMMException: Specified a Platform for a Context which does not support all required kernels

Here is my mamba list:

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
ase                       3.22.1             pyhd8ed1ab_1    conda-forge
blinker                   1.8.2              pyhd8ed1ab_0    conda-forge
brotli                    1.1.0                hd590300_1    conda-forge
brotli-bin                1.1.0                hd590300_1    conda-forge
brotli-python             1.1.0           py311hb755f60_1    conda-forge
bzip2                     1.0.8                hd590300_5    conda-forge
c-ares                    1.28.1               hd590300_0    conda-forge
ca-certificates           2024.2.2             hbcca054_0    conda-forge
cached-property           1.5.2                hd8ed1ab_1    conda-forge
cached_property           1.5.2              pyha770c72_1    conda-forge
certifi                   2024.2.2           pyhd8ed1ab_0    conda-forge
charset-normalizer        3.3.2              pyhd8ed1ab_0    conda-forge
click                     8.1.7           unix_pyh707e725_0    conda-forge
contourpy                 1.2.1           py311h9547e67_0    conda-forge
cudatoolkit               11.5.2              hbdc67f6_13    conda-forge
cycler                    0.12.1             pyhd8ed1ab_0    conda-forge
e3nn                      0.5.1                    pypi_0    pypi
filelock                  3.14.0             pyhd8ed1ab_0    conda-forge
flask                     3.0.3              pyhd8ed1ab_0    conda-forge
fonttools                 4.51.0          py311h459d7ec_0    conda-forge
freetype                  2.12.1               h267a509_2    conda-forge
fsspec                    2024.3.1           pyhca7485f_0    conda-forge
gmp                       6.3.0                h59595ed_1    conda-forge
gmpy2                     2.1.5           py311he48d604_0    conda-forge
h5py                      3.11.0          nompi_py311hebc2b07_100    conda-forge
hdf5                      1.14.3          nompi_h4f84152_101    conda-forge
idna                      3.7                pyhd8ed1ab_0    conda-forge
importlib-metadata        7.1.0              pyha770c72_0    conda-forge
importlib_metadata        7.1.0                hd8ed1ab_0    conda-forge
itsdangerous              2.2.0              pyhd8ed1ab_0    conda-forge
jinja2                    3.1.3              pyhd8ed1ab_0    conda-forge
keyutils                  1.6.1                h166bdaf_0    conda-forge
kiwisolver                1.4.5           py311h9547e67_1    conda-forge
krb5                      1.21.2               h659d440_0    conda-forge
lark-parser               0.12.0             pyhd8ed1ab_0    conda-forge
lcms2                     2.16                 hb7c19ff_0    conda-forge
ld_impl_linux-64          2.40                 h55db66e_0    conda-forge
lerc                      4.0.0                h27087fc_0    conda-forge
libabseil                 20230802.1      cxx17_h59595ed_0    conda-forge
libaec                    1.1.3                h59595ed_0    conda-forge
libblas                   3.9.0           22_linux64_openblas    conda-forge
libbrotlicommon           1.1.0                hd590300_1    conda-forge
libbrotlidec              1.1.0                hd590300_1    conda-forge
libbrotlienc              1.1.0                hd590300_1    conda-forge
libcblas                  3.9.0           22_linux64_openblas    conda-forge
libcurl                   8.7.1                hca28451_0    conda-forge
libdeflate                1.20                 hd590300_0    conda-forge
libedit                   3.1.20191231         he28a2e2_2    conda-forge
libev                     4.33                 hd590300_2    conda-forge
libexpat                  2.6.2                h59595ed_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-ng                 13.2.0               h77fa898_7    conda-forge
libgfortran-ng            13.2.0               h69a702a_7    conda-forge
libgfortran5              13.2.0               hca663fb_7    conda-forge
libgomp                   13.2.0               h77fa898_7    conda-forge
libjpeg-turbo             3.0.0                hd590300_1    conda-forge
liblapack                 3.9.0           22_linux64_openblas    conda-forge
libnghttp2                1.58.0               h47da74e_1    conda-forge
libnsl                    2.0.1                hd590300_0    conda-forge
libopenblas               0.3.27          pthreads_h413a1c8_0    conda-forge
libpng                    1.6.43               h2797004_0    conda-forge
libprotobuf               4.25.1               hf27288f_2    conda-forge
libsqlite                 3.45.3               h2797004_0    conda-forge
libssh2                   1.11.0               h0841786_0    conda-forge
libstdcxx-ng              13.2.0               hc0a3c3a_7    conda-forge
libtiff                   4.6.0                h1dd3fc0_3    conda-forge
libtorch                  2.1.2           cpu_generic_ha017de0_3    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libuv                     1.48.0               hd590300_0    conda-forge
libwebp-base              1.4.0                hd590300_0    conda-forge
libxcb                    1.15                 h0b41bf4_0    conda-forge
libxcrypt                 4.4.36               hd590300_1    conda-forge
libzlib                   1.2.13               hd590300_5    conda-forge
markupsafe                2.1.5           py311h459d7ec_0    conda-forge
matplotlib-base           3.8.4           py311h54ef318_0    conda-forge
mpc                       1.3.1                hfe3b2da_0    conda-forge
mpfr                      4.2.1                h9458935_1    conda-forge
mpmath                    1.3.0              pyhd8ed1ab_0    conda-forge
munkres                   1.1.4              pyh9f0ad1d_0    conda-forge
ncurses                   6.4.20240210         h59595ed_0    conda-forge
nequip                    0.6.0                    pypi_0    pypi
networkx                  3.3                pyhd8ed1ab_1    conda-forge
nnpops                    0.6             cpu_py311h7697b17_7    conda-forge
nomkl                     1.0                  h5ca1d4c_0    conda-forge
numpy                     1.26.4          py311h64a7726_0    conda-forge
ocl-icd                   2.3.2                hd590300_1    conda-forge
ocl-icd-system            1.0.0                         1    conda-forge
openjpeg                  2.5.2                h488ebb8_0    conda-forge
openmm                    8.1.1           py311h28d7ac7_1    conda-forge
openmm-torch              1.4             cpu_py311h446247e_4    conda-forge
openmmml                  1.1                      pypi_0    pypi
openssl                   3.3.0                hd590300_0    conda-forge
opt-einsum                3.3.0                    pypi_0    pypi
opt-einsum-fx             0.1.4                    pypi_0    pypi
packaging                 24.0               pyhd8ed1ab_0    conda-forge
pillow                    10.3.0          py311h18e6fac_0    conda-forge
pip                       24.0               pyhd8ed1ab_0    conda-forge
pthread-stubs             0.4               h36c2ea0_1001    conda-forge
pyparsing                 3.1.2              pyhd8ed1ab_0    conda-forge
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
python                    3.11.9          hb806964_0_cpython    conda-forge
python-dateutil           2.9.0              pyhd8ed1ab_0    conda-forge
python_abi                3.11                    4_cp311    conda-forge
pytorch                   2.1.2           cpu_generic_py311h1584bb0_3    conda-forge
pyyaml                    6.0.1                    pypi_0    pypi
readline                  8.2                  h8228510_1    conda-forge
requests                  2.31.0             pyhd8ed1ab_0    conda-forge
scipy                     1.13.0          py311h517d4fd_1    conda-forge
setuptools                65.3.0             pyhd8ed1ab_1    conda-forge
setuptools-scm            6.3.2              pyhd8ed1ab_0    conda-forge
setuptools_scm            6.3.2                hd8ed1ab_0    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
sleef                     3.5.1                h9b69904_2    conda-forge
sympy                     1.12            pypyh9d50eac_103    conda-forge
tk                        8.6.13          noxft_h4845f30_101    conda-forge
tomli                     2.0.1              pyhd8ed1ab_0    conda-forge
torch-ema                 0.3                      pypi_0    pypi
torch-runstats            0.2.0                    pypi_0    pypi
torchani                  2.2.4           cpu_py311h12a0d1d_3    conda-forge
tqdm                      4.66.4                   pypi_0    pypi
typing_extensions         4.11.0             pyha770c72_0    conda-forge
tzdata                    2024a                h0c530f3_0    conda-forge
urllib3                   2.2.1              pyhd8ed1ab_0    conda-forge
werkzeug                  3.0.3              pyhd8ed1ab_0    conda-forge
wheel                     0.43.0             pyhd8ed1ab_1    conda-forge
xorg-libxau               1.0.11               hd590300_0    conda-forge
xorg-libxdmcp             1.1.3                h7f98852_0    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
zipp                      3.17.0             pyhd8ed1ab_0    conda-forge
zstd                      1.5.6                ha6fb4c9_0    conda-forge

If I don't specify any platform, then the simulation runs, but extremely slowly since it is on CPU.

Thank you so much in advance!

Best, Sam

peastman commented 1 month ago

That means a plugin couldn't be loaded. Try printing the value of Platform.getPluginLoadFailures(). It will tell you which ones failed, and what the errors were.

Usually it's because some library they depended on couldn't be found, and it can be fixed by adding the directory containing the library to LD_LIBRARY_PATH.
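A small diagnostic along these lines can gather the relevant information in one place. It is only a sketch: the function name `describe_openmm_platforms` is made up, and the import is guarded so the snippet also runs (and reports nothing) where OpenMM is not installed.

```python
def describe_openmm_platforms():
    """Return platform names, loaded plugins and load failures, or None without OpenMM."""
    try:
        import openmm
    except ImportError:
        return None
    platforms = [
        openmm.Platform.getPlatform(i).getName()
        for i in range(openmm.Platform.getNumPlatforms())
    ]
    return {
        "platforms": platforms,
        # pluginLoadedLibNames is set when the openmm package loads its plugins.
        "loaded_plugins": list(getattr(openmm, "pluginLoadedLibNames", ())),
        "load_failures": list(openmm.Platform.getPluginLoadFailures()),
    }


report = describe_openmm_platforms()
if report is None:
    print("OpenMM is not installed in this environment")
else:
    for key, value in report.items():
        print(key, value)
```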

svarner9 commented 1 month ago

Thank you for the quick response!

I tried that based on some previous replies of yours that I found. I ran the following:

print(pluginLoadedLibNames)
print(Platform.getPluginLoadFailures())

and the output was:

('/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMPME.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMCPU.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMCUDA.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMOpenCL.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMRPMDCUDA.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMDrudeCUDA.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMAmoebaCUDA.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMRPMDOpenCL.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMTorchOpenCL.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMDrudeOpenCL.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMAmoebaOpenCL.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMRPMDReference.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMTorchReference.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMDrudeReference.so', '/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMAmoebaReference.so')

()

The failures command returned an empty tuple.
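An empty failure list does not by itself show that every needed plugin exists. In the output above, OpenMM-Torch plugins appear for OpenCL and Reference but none for CUDA. Assuming a CUDA build would follow the same naming pattern (that is, `libOpenMMTorchCUDA.so`, which is an assumption here), the check can be sketched as follows; `torch_plugin_backends` is a hypothetical helper and the list is a subset of the paths printed above.

```python
import os
import re


def torch_plugin_backends(plugin_paths):
    """Return the set of backends for which an OpenMM-Torch plugin was loaded."""
    backends = set()
    for path in plugin_paths:
        match = re.fullmatch(r"libOpenMMTorch(\w+)\.so", os.path.basename(path))
        if match:
            backends.add(match.group(1))
    return backends


loaded = [
    "/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMCUDA.so",
    "/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMTorchOpenCL.so",
    "/home/svarner/miniconda3/envs/practicum/lib/plugins/libOpenMMTorchReference.so",
]
backends = torch_plugin_backends(loaded)
print(sorted(backends))  # ['OpenCL', 'Reference']
```

The core CUDA platform is present, but nothing provides the Torch kernels for it, which would be consistent with the "does not support all required kernels" error.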

Best, Sam

peastman commented 1 month ago

The versions of PyTorch and OpenMM-Torch you have installed are CPU only:

openmm-torch              1.4             cpu_py311h446247e_4    conda-forge
pytorch                   2.1.2           cpu_generic_py311h1584bb0_3    conda-forge

That might be because you have an older version of cudatoolkit:

cudatoolkit               11.5.2              hbdc67f6_13    conda-forge

If you upgrade it to 11.8, you might be able to get it to install the CUDA version of PyTorch. Conda installation issues like this tend to be frustrating and hard to figure out. They often depend on the precise order you install packages in.
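Spotting CPU-only builds in a long package listing is easy to automate. The sketch below scans `mamba list` style text for build strings that start with `cpu`; the helper name `cpu_only_packages` is made up, and the sample rows are copied from the listing earlier in this thread.

```python
def cpu_only_packages(listing):
    """Return names of packages whose conda build string starts with 'cpu'."""
    flagged = []
    for line in listing.splitlines():
        fields = line.split()
        if line.startswith("#") or len(fields) < 3:
            continue  # skip the header and malformed rows
        name, _version, build = fields[:3]
        if build.startswith("cpu"):
            flagged.append(name)
    return flagged


listing = """\
# Name                    Version                   Build  Channel
cudatoolkit               11.5.2              hbdc67f6_13    conda-forge
nnpops                    0.6             cpu_py311h7697b17_7    conda-forge
openmm                    8.1.1           py311h28d7ac7_1    conda-forge
openmm-torch              1.4             cpu_py311h446247e_4    conda-forge
pytorch                   2.1.2           cpu_generic_py311h1584bb0_3    conda-forge
"""
print(cpu_only_packages(listing))  # ['nnpops', 'openmm-torch', 'pytorch']
```

Run against the full listing, the same check would also flag libtorch and torchani, both of which show `cpu` builds there.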

svarner9 commented 1 month ago

Ahhh I see. Thank you!

I went ahead and uninstalled openmm-torch and pytorch. I upgraded cudatoolkit and then installed the CUDA version of PyTorch:

install pytorch pytorch-cuda=11.8 -c pytorch -c nvidia

Installing openmm-torch downgraded it back to the cpu version, but then installing nnpops upgraded it back to the cuda version. I agree, conda installations are very frustrating.

It is working on the GPU now, but I am only getting about 0.2 ns/day, whereas with LAMMPS I was getting 1.5 ns/day. To your knowledge, could any of the following warnings be related to the slowdown?

/home/svarner/miniconda3/envs/practicum/lib/python3.11/site-packages/nequip/scripts/deploy.py:138: UserWarning: Models deployed before v0.6.0 don't contain information about their default_dtype or model_dtype; assuming the old default of float32 for both, but this might not be right if you had explicitly set default_dtype=float64.
  warnings.warn(
/home/svarner/miniconda3/envs/practicum/lib/python3.11/site-packages/nequip/utils/_global_options.py:59: UserWarning: !! Upstream issues in PyTorch versions >1.11 have been seen to cause unusual performance degredations on some CUDA systems that become worse over time; see https://github.com/mir-group/nequip/discussions/311. At present we *strongly* recommend the use of PyTorch 1.11 if using CUDA devices; while using other versions if you observe this problem, an unexpected lack of this problem, or other strange behavior, please post in the linked GitHub issue.
  warnings.warn(
/home/svarner/miniconda3/envs/practicum/lib/python3.11/site-packages/nequip/utils/_global_options.py:70: UserWarning: Setting the GLOBAL value for jit fusion strategy to `[('DYNAMIC', 3)]` which is different than the previous value of `[('STATIC', 2), ('DYNAMIC', 10)]`
  warnings.warn(

I tried to install the packages in a way that would let me use PyTorch 1.11.0 (which, according to the warning, is the most stable version with NequIP). However, as far as I can tell, there is no way to use PyTorch 1.11.0 with openmm-torch: every time I installed openmm-torch, it installed PyTorch 2.1.2.

This is the order that I did everything:

mamba create -n env
mamba activate env
mamba install python=3.10
mamba install -c conda-forge openmm cudatoolkit=11.8
pip install git+https://github.com/mir-group/nequip@develop
pip install git+https://github.com/sef43/openmm-ml@nequip
mamba install pytorch=1.11 pytorch-cuda=11.8 -c pytorch -c nvidia
mamba install -c conda-forge openmm-torch nnpops

JMorado commented 1 month ago

Many thanks for the thorough review, @peastman! Most of it should be now resolved.

Thanks for testing, @svarner9. I don't think the slow performance you're seeing is related to that warning; the underlying issue is described here. You could check whether that issue is actually present by looking for a slowdown in performance over the course of a run. I ran some performance benchmarks on systems much smaller than yours and did not see any decrease in performance over time, and the simulation speed is around what I would expect.
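
One way to look for that kind of gradual slowdown is to time the run in equal blocks and compare throughput early and late. The sketch below is only an illustration: `block_rates` and `slowdown_factor` are hypothetical helper names, and the stand-in workload would be replaced by something like `simulation.step` from an OpenMM `Simulation`. The first block is skipped by default on the assumption that it includes one-off warm-up costs such as JIT compilation.

```python
import time
from typing import Callable


def block_rates(step: Callable[[int], None], n_blocks: int, steps_per_block: int) -> list[float]:
    """Run n_blocks blocks of steps and return the steps-per-second rate of each block."""
    rates = []
    for _ in range(n_blocks):
        start = time.perf_counter()
        step(steps_per_block)
        # Guard against a zero reading for extremely short blocks.
        elapsed = max(time.perf_counter() - start, 1e-12)
        rates.append(steps_per_block / elapsed)
    return rates


def slowdown_factor(rates: list[float], warmup: int = 1) -> float:
    """Throughput of the first post-warm-up block divided by that of the last block.

    Values well above 1 suggest performance is degrading over time.
    """
    if len(rates) <= warmup:
        raise ValueError("need more blocks than warm-up blocks")
    return rates[warmup] / rates[-1]


def fake_step(n: int) -> None:
    """Stand-in workload so the sketch runs on its own."""
    sum(i * i for i in range(1000 * n))


rates = block_rates(fake_step, n_blocks=5, steps_per_block=10)
print([round(r, 1) for r in rates], round(slowdown_factor(rates), 2))
```

A factor that stays close to 1 over a long run would be consistent with the warning not being the cause of the low ns/day figure.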

If that is your baseline OpenMM performance, I wonder what could be causing it. Do you happen to remember what performance you were getting with the previous neighbor list? Does anyone have any ideas about whether it's possible to improve performance here?

svarner9 commented 1 month ago

Yes many thanks @peastman for the help!

@JMorado I am not sure, but I can think of a few things that might be the issue. I am not an expert and have not looked through the code, so these might be a bit naive.

  1. In LAMMPS the nequip pairstyle works with Kokkos, so in that case I was using 1 gpu + 32 cpus.
    mpiexec -n 1 ./lmp -in in.script -k on g 1 t 32 -sf kk -pk kokkos newton on neigh full
  2. The LAMMPS nequip pairstyle uses libtorch instead of pytorch, which could make a difference?
  3. When reading in the model, is the cutoff set to the cutoff of the MLP? Most of them have very short cutoffs of around 5 Angstroms, so if that cutoff is not being used for neighborlists, then that could be leading to slow performance. Is that something that should be set separately?
  4. I am getting this warning for jit but I am not sure if it is important or could be affecting performance. I have seen the NequIP devs say that it can usually be silently ignored.
    /home/svarner/miniconda3/envs/practicum/lib/python3.10/site-packages/nequip/utils/_global_options.py:70: UserWarning: Setting the GLOBAL value for jit fusion strategy to `[('DYNAMIC', 3)]` which is different than the previous value of `[('STATIC', 2), ('DYNAMIC', 10)]`
    warnings.warn(

Best, Sam
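
To illustrate why the cutoff in item 3 matters: the number of pairs in a neighbor list grows roughly with the cube of the cutoff, so doubling it gives about eight times as many edges for the model to process. The sketch below uses a naive O(N²) pair count on random positions. The box size and cutoffs are hypothetical, and this is not how openmm-ml or NNPOps actually builds its neighbor list.

```python
import itertools
import random


def count_pairs(positions, box, cutoff):
    """Count unique pairs within `cutoff` in a cubic periodic box (minimum image).

    Valid only for cutoff <= box / 2.
    """
    cutoff2 = cutoff * cutoff
    n_pairs = 0
    for (x1, y1, z1), (x2, y2, z2) in itertools.combinations(positions, 2):
        d2 = 0.0
        for a, b in ((x1, x2), (y1, y2), (z1, z2)):
            d = a - b
            d -= box * round(d / box)  # minimum-image convention
            d2 += d * d
        if d2 < cutoff2:
            n_pairs += 1
    return n_pairs


random.seed(0)
box = 20.0     # hypothetical cubic box edge, in Angstroms
n_atoms = 645  # system size mentioned in this thread
positions = [tuple(random.uniform(0.0, box) for _ in range(3)) for _ in range(n_atoms)]

pairs_5 = count_pairs(positions, box, 5.0)
pairs_10 = count_pairs(positions, box, 10.0)
print(pairs_5, pairs_10, round(pairs_10 / pairs_5, 2))
```

For uniformly distributed atoms the ratio comes out close to 8, which is why a neighbor list built with a cutoff larger than the model's own could cost noticeably more per step.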

Linux-cpp-lisp commented 1 month ago

Is there an option to predict a formation energy instead of total energy, or to subtract off per-atom mean energies? That leads to a much smaller output value and better accuracy.

We actually do this internally, at least from develop onward---single precision calculations are done in a more numerically favorable range, and the final energy scalings, shiftings, and sums are done in float64, regardless of the precision of the weights. The final predictions you get should be float64, and if they aren't, something might be off.

Regarding the reproducibility of energies between ASE and OpenMM: you can try turning off TF32, or even better using a fully F64 model (default_dtype: float64 and model_dtype: float64) to ensure that this is just numerics as a sanity check.
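
The precision argument above can be seen with plain floating-point arithmetic. The sketch below uses made-up numbers and Python's `struct` module to emulate float32 rounding. It compares storing a large total energy directly in float32 against keeping only the small configuration-dependent part in float32 and adding the constant offset back in float64. It illustrates the numerics only and is not NequIP's actual implementation.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


# Hypothetical numbers: a large constant offset plus a small
# configuration-dependent contribution (arbitrary energy units).
offset = -250000.0
small = -0.0123456
exact = offset + small

# (a) Store the full total energy in float32.
err_direct = abs(to_float32(exact) - exact)

# (b) Keep only the small part in float32, then add the offset in float64.
err_shifted = abs((offset + to_float32(small)) - exact)

print(f"float32 total energy:       error = {err_direct:.2e}")
print(f"float32 part + f64 offset:  error = {err_shifted:.2e}")
```

With these numbers the first error is several orders of magnitude larger than the second, because float32 has only about seven significant decimal digits and most of them are spent on the offset.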

Linux-cpp-lisp commented 1 month ago

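@svarner9 a few questions on performance:

  • What are the actual LAMMPS vs OpenMM numbers? Not sure where they were in this thread.
  • Yes, there will be additional Python and doubled neighborlist overhead in OpenMM, both of which are absent in pair_allegro. This should be more important for smaller models and smaller systems.
  • You can ignore that particular warning about the fusion strategy safely, it is just there to ensure that nequip never silently sets global state when called from someone else's program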

peastman commented 1 month ago

There shouldn't be any overhead from Python. The model gets compiled to TorchScript, and the simulation is run by C++ code.

Linux-cpp-lisp commented 1 month ago

Do you call TorchScript from Python here, or directly from C++? Not that I would expect a roundtrip through Python to matter much, just curious.

peastman commented 1 month ago

It's called directly from C++.

JMorado commented 1 month ago

@peastman @Linux-cpp-lisp, I've trained a model with these settings:

default_dtype: float64
model_dtype: float64
allow_tf32: true  

and the energy and force differences between ASE and OpenMM are indeed very small, on the order of $10^{-10}$, when combined with {"Precision": "double"} in the simulation settings.
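
A comparison like this can be reduced to a single number per frame. The sketch below is a hedged illustration with made-up force values: `max_abs_diff` is a hypothetical helper, and both force sets are assumed to have already been converted to the same units (ASE works in eV and Å, OpenMM in kJ/mol and nm).

```python
def max_abs_diff(forces_a, forces_b):
    """Largest absolute component-wise difference between two lists of 3-vectors."""
    if len(forces_a) != len(forces_b):
        raise ValueError("force arrays must have the same number of atoms")
    return max(
        abs(x - y)
        for vec_a, vec_b in zip(forces_a, forces_b)
        for x, y in zip(vec_a, vec_b)
    )


# Made-up forces for two atoms, differing by 1e-10 in one component.
forces_ref = [(1.0, 2.0, 3.0), (0.5, -0.5, 0.0)]
forces_new = [(1.0, 2.0, 3.0 + 1e-10), (0.5, -0.5, 0.0)]

diff = max_abs_diff(forces_ref, forces_new)
print(f"max |dF| = {diff:.1e}")
```

Reporting the maximum rather than the mean makes a single badly wrong atom visible, which is usually what one wants from a sanity check of this kind.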

Linux-cpp-lisp commented 1 month ago

@JMorado great!

(Note that allow_tf32: true is a no-op when model_dtype: float64 and we should probably error on this configuration, but that doesn't change the results.)

svarner9 commented 1 month ago

@svarner9 a few questions on performance:

  • What are the actual LAMMPS vs OpenMM numbers? Not sure where they were in this thread.
  • Yes, there will be additional Python and doubled neighborlist overhead in OpenMM, both of which are absent in pair_allegro. This should be more important for smaller models and smaller systems.
  • You can ignore that particular warning about the fusion strategy safely, it is just there to ensure that nequip never silently sets global state when called from someone else's program

@Linux-cpp-lisp I was getting 1.5 ns/day with LAMMPS and 0.2 ns/day with OpenMM for a system with 645 atoms.
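
To put these figures in per-step terms, the ns/day numbers can be converted to wall-clock time per integration step. The timestep is not stated in this thread, so the 1 fs used below is purely an assumption; the 7.5x ratio between the two engines does not depend on it as long as both runs used the same timestep.

```python
SECONDS_PER_DAY = 86_400.0


def seconds_per_step(ns_per_day: float, timestep_fs: float) -> float:
    """Wall-clock seconds per MD step for a given throughput and timestep."""
    steps_per_day = ns_per_day * 1.0e6 / timestep_fs  # 1 ns = 1e6 fs
    return SECONDS_PER_DAY / steps_per_day


timestep_fs = 1.0  # assumption: the thread does not give the timestep
lammps_rate, openmm_rate = 1.5, 0.2  # ns/day, as reported above

for engine, rate in (("LAMMPS", lammps_rate), ("OpenMM", openmm_rate)):
    ms = seconds_per_step(rate, timestep_fs) * 1e3
    print(f"{engine}: {ms:.1f} ms/step")
print(f"ratio: {lammps_rate / openmm_rate:.1f}x")
```

Under that assumption the gap is roughly 58 ms versus 432 ms per step, which is far more than a per-call Python or neighbor-list overhead would normally be expected to account for on its own.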