openmm / openmm-ml

High level API for using machine learning models in OpenMM simulations

Simulation performance steadily declines #62

Closed wiederm closed 1 year ago

wiederm commented 1 year ago

I ran a pure water box simulation in a 25 Angstrom box with ANI-2x and the TorchANI implementation. The simulation performance decreased from ~2 ns/day to 0.3 ns/day within the first 1 ns of simulation time.

I've attached a script to reproduce this behavior and the output of the StateDataReporter. Is there anything that I need to do differently?

#"Step","Time (ps)","Potential Energy (kJ/mole)","Total Energy (kJ/mole)","Temperature (K)","Density (g/mL)","Speed (ns/day)"
500,0.5000000000000003,-100516778.12917194,-100514547.79711843,118.98305105506671,0.9591993474692536,0
1000,1.0000000000000007,-100515425.77060814,-100511819.93468693,192.36299762377715,0.9591993474692536,2.06
1500,1.4999999999999456,-100514835.4803809,-100510368.56665199,238.29950496707687,0.9591993474692536,1.77
2000,1.9999999999998905,-100514522.33117871,-100509525.64425702,266.5616781910217,0.9591993474692536,1.86
2500,2.4999999999998357,-100514196.04165852,-100508978.46768452,278.3454910015101,0.9591993474692536,1.9
3000,2.9999999999997806,-100514152.50934735,-100508801.46777023,285.465678603462,0.9591993474692536,1.93
3500,3.4999999999997256,-100513990.48632172,-100508530.74921964,291.26433319200413,0.9591993474692536,1.95
4000,3.9999999999996705,-100514099.18439446,-100508432.39956689,302.309849959495,0.9591993474692536,1.96
4500,4.4999999999998375,-100513837.67253573,-100508312.70271303,294.74434780102064,0.9591993474692536,1.97
...
995500,995.4999999833715,-100515995.70435952,-100510307.35485645,303.4602754673055,0.9591993474692536,0.304
996000,995.9999999833597,-100515819.44682398,-100510114.86655764,304.32614910742063,0.9591993474692536,0.304
996500,996.4999999833478,-100515949.61809933,-100510331.93236865,299.6905268216136,0.9591993474692536,0.304
997000,996.999999983336,-100516013.21894117,-100510187.2782918,310.80044455015786,0.9591993474692536,0.304
997500,997.4999999833242,-100515947.76273051,-100510117.73721096,311.01836291642246,0.9591993474692536,0.304
998000,997.9999999833124,-100516183.55732429,-100510327.88930751,312.38633077208556,0.9591993474692536,0.304
998500,998.4999999833005,-100515961.05077694,-100510513.25682221,290.6271939333003,0.9591993474692536,0.304
999000,998.9999999832887,-100515947.95052087,-100510372.42897089,297.4411654058028,0.9591993474692536,0.304
999500,999.4999999832769,-100515917.83395359,-100510158.53726706,307.2451434368191,0.9591993474692536,0.303
1000000,999.9999999832651,-100515956.23833577,-100510335.1903263,299.8698966099735,0.9591993474692536,0.303

waterbox_simulation_1ns.zip
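One caveat when reading the Speed column: as far as I understand StateDataReporter, the value is a running average since the reporter started timing, not the speed of the most recent interval. Under that assumption, the per-interval speed can be recovered from consecutive rows. The sketch below does this for the first rows of the output above. `parse_reporter` and `block_speeds` are hypothetical helper names, and since the reported speeds are rounded to three significant digits, the recovered values become noisy late in a long run.

```python
import csv
import io

# First rows of the reporter output above, copied verbatim.
LOG = '''#"Step","Time (ps)","Potential Energy (kJ/mole)","Total Energy (kJ/mole)","Temperature (K)","Density (g/mL)","Speed (ns/day)"
500,0.5000000000000003,-100516778.12917194,-100514547.79711843,118.98305105506671,0.9591993474692536,0
1000,1.0000000000000007,-100515425.77060814,-100511819.93468693,192.36299762377715,0.9591993474692536,2.06
1500,1.4999999999999456,-100514835.4803809,-100510368.56665199,238.29950496707687,0.9591993474692536,1.77
2000,1.9999999999998905,-100514522.33117871,-100509525.64425702,266.5616781910217,0.9591993474692536,1.86
'''


def parse_reporter(text):
    """Return (time_ps, speed_ns_per_day) for every data row in the text."""
    rows = []
    for record in csv.reader(io.StringIO(text)):
        # Skip the header, blank lines, and interleaved warnings.
        if len(record) < 7 or not record[0].strip().isdigit():
            continue
        rows.append((float(record[1]), float(record[-1])))
    return rows


def block_speeds(rows):
    """Estimate the speed (ns/day) of each interval between reports.

    Assumption: the Speed column is a running average whose clock starts
    at the first reported row (which is why that row shows 0).
    """
    t0 = rows[0][0]
    prev_t, prev_wall = t0, 0.0
    result = []
    for t, avg_speed in rows[1:]:
        if avg_speed <= 0:
            continue
        wall_days = (t - t0) / 1000.0 / avg_speed  # elapsed wall time so far
        dt_ns = (t - prev_t) / 1000.0              # simulated time in this interval
        d_wall = wall_days - prev_wall             # wall time spent in this interval
        if d_wall > 0:
            result.append((t, dt_ns / d_wall))
        prev_t, prev_wall = t, wall_days
    return result


rows = parse_reporter(LOG)
blocks = block_speeds(rows)
for time_ps, speed in blocks:
    print(f"{time_ps:8.1f} ps  {speed:.2f} ns/day")
```

If the running-average assumption holds, a slow drift in the reported number can hide a much larger change in the actual per-interval speed, so the per-interval values are the ones worth plotting.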

peastman commented 1 year ago

This is likely the same issue as https://github.com/openmm/openmm/issues/4277. I'll try your script and see what I can tell.

peastman commented 1 year ago

When I first tried to run your script it died, I think from running out of memory. I reduced the box size from 30 to 20 A and tried again. It's now past 100,000 steps and hasn't shown the slightest change in speed. It's completely constant at 3.27 ns/day.

What kind of GPU are you running on?

How much memory do you have (both main memory and GPU memory)?

Does top or nvidia-smi show any obvious change as the simulation runs and it slows down?
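To make that last check repeatable, GPU memory and utilization can be sampled from a small script while the simulation runs. This is only a sketch: `gpu_snapshot` is a hypothetical helper, it relies on the `--query-gpu` and `--format` options of `nvidia-smi`, and it returns the raw strings because some GPUs report fields as unsupported.

```python
import shutil
import subprocess


def gpu_snapshot():
    """Return one (memory_used_MiB, gpu_utilization_percent) tuple of strings
    per GPU, or None if nvidia-smi is missing or fails."""
    if shutil.which("nvidia-smi") is None:
        return None
    try:
        result = subprocess.run(
            [
                "nvidia-smi",
                "--query-gpu=memory.used,utilization.gpu",
                "--format=csv,noheader,nounits",
            ],
            capture_output=True,
            text=True,
            timeout=10,
            check=False,
        )
    except (OSError, subprocess.TimeoutExpired):
        return None
    if result.returncode != 0:
        return None
    return [
        tuple(field.strip() for field in line.split(","))
        for line in result.stdout.splitlines()
        if line.strip()
    ]


snapshot = gpu_snapshot()
print(snapshot)
```

Calling this every few thousand steps and logging the result next to the reporter output would show whether memory use grows as the speed drops.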

peastman commented 1 year ago

Your script doesn't specify which implementation to use, which means it defaults to NNPOps. I tried switching to TorchANI. The results are different in a few ways.

I assume this doesn't match the behavior you see?

wiederm commented 1 year ago

Oh, sorry, this was my mistake. I just reran the calculation using the script with a 20 A water box and TorchANI specified. I ran it with and without a barostat, and I see a steady decline in both cases, although without the barostat the performance seems to fluctuate more. Just to be sure, the script that I used is attached again, and here is the output:

NPT simulation:

#"Step","Time (ps)","Potential Energy (kJ/mole)","Total Energy (kJ/mole)","Temperature (K)","Density (g/mL)","Speed (ns/day)"
500,0.5000000000000003,-51758323.17948345,-51754248.0107705,422.16204849531186,0.9722872069973723,0
1000,1.0000000000000007,-51759147.95102386,-51755283.87097654,400.29457998769965,1.016773100888847,3.13
1500,1.4999999999999456,-51759820.19672352,-51756138.66579172,381.3836307754993,1.0510115155220032,3.4
2000,1.9999999999998905,-51760338.777320854,-51756900.16161244,356.21912948962114,1.0664072542341871,3.42
2500,2.4999999999998357,-51760758.35358425,-51757528.17420589,334.62642639627956,1.0925431575750892,3.47
3000,2.9999999999997806,-51760991.84962394,-51757858.954556264,324.5483788896475,1.0970654747606259,3.51
3500,3.4999999999997256,-51761291.78339195,-51758208.74261621,319.3837853504949,1.0859991266081934,3.53
4000,3.9999999999996705,-51761293.932965696,-51758373.16521566,302.5733125144116,1.0781928064455553,3.54
4500,4.4999999999998375,-51761389.69854176,-51758342.53728326,315.66689126187975,1.0740626163771807,3.55
5000,5.000000000000004,-51761451.32758472,-51758394.36149901,316.68261017555335,1.0854896558779137,3.52
5500,5.500000000000171,-51761630.56347548,-51758519.43933061,322.29304714396596,1.0743636728322803,3.52
6000,6.000000000000338,-51761781.32659065,-51758733.8825156,315.6961892914457,1.0772504814989818,3.5
6500,6.500000000000505,-51761817.164504476,-51758961.6006041,295.8186005794938,1.0690523989616918,3.47
7000,7.000000000000672,-51761891.047494926,-51759135.90546453,285.4155215101733,1.073426970059492,3.45
7500,7.500000000000839,-51761926.23315021,-51759082.5611561,294.5866732989508,1.0683137639301143,3.43
8000,8.000000000001005,-51762092.89459793,-51759113.10243226,308.6878736463885,1.0815017004564533,3.4
8500,8.500000000000728,-51761924.87605181,-51758889.84905224,314.4098577572619,1.0761943966069705,3.38
9000,9.000000000000451,-51762078.63254545,-51759074.950417206,311.16272468847677,1.089365980809329,3.35
9500,9.500000000000174,-51762063.436548874,-51758964.83066412,320.99623351305905,1.0809373591524567,3.34
10000,9.999999999999897,-51761912.49315495,-51759083.97252353,293.0170866729598,1.0715416940543543,3.32
10500,10.49999999999962,-51762129.75033555,-51759086.39102816,315.27303285226253,1.0828157972691668,3.3
11000,10.999999999999343,-51762026.00116498,-51759075.56480984,305.64679486403946,1.0783165303335078,3.28
11500,11.499999999999066,-51761999.11459202,-51759052.53345661,305.24741815736957,1.0737631483220813,3.26
12000,11.999999999998789,-51761998.99440619,-51759100.89004315,300.22552704853996,1.0787897255409897,3.25
12500,12.499999999998511,-51761964.03284741,-51758971.40169781,310.0179122627973,1.0951528020515506,3.23
13000,12.999999999998234,-51762093.12870992,-51759303.36463113,289.0021496719478,1.0884880921900235,3.21
13500,13.499999999997957,-51762085.62460682,-51759316.34713682,286.87986484414796,1.0920590340863168,3.19
14000,13.99999999999768,-51762279.392968185,-51759311.74913634,307.4293820706241,1.0900549230218977,3.17
14500,14.499999999997403,-51762098.478231534,-51759240.743231185,296.0435129886504,1.0916635431619834,3.15
15000,14.999999999997126,-51762110.05237789,-51759251.490792066,296.12914208531856,1.0811324276027054,3.13
15500,15.499999999996849,-51762049.05431055,-51759212.76903817,293.8214550238599,1.0707322915558575,3.12
16000,15.999999999996572,-51761996.903673425,-51759057.07185785,304.54823073983704,1.0699014714110708,3.1

NVT simulation

#"Step","Time (ps)","Potential Energy (kJ/mole)","Total Energy (kJ/mole)","Temperature (K)","Density (g/mL)","Speed (ns/day)"
500,0.5000000000000003,-51758194.96999148,-51754110.56663234,423.1187002142445,0.9647635652408765,0
1000,1.0000000000000007,-51758869.87605468,-51755264.09178782,373.53674896744985,0.9647635652408765,4.21
1500,1.4999999999999456,-51759697.61843871,-51756077.366258785,375.03553447734737,0.9647635652408765,3.96
2000,1.9999999999998905,-51760061.87791938,-51756544.692610994,364.35844975134404,0.9647635652408765,3.96
2500,2.4999999999998357,-51760482.24791007,-51757040.891249895,356.5030749874752,0.9647635652408765,3.88
3000,2.9999999999997806,-51760583.10510897,-51757357.569603704,334.14535012254254,0.9647635652408765,3.86
3500,3.4999999999997256,-51761014.60355686,-51757740.243185736,339.203301544475,0.9647635652408765,3.84
4000,3.9999999999996705,-51761257.93230022,-51758068.37618068,330.4181102823905,0.9647635652408765,3.83
4500,4.4999999999998375,-51761460.67954505,-51758407.55896859,316.2842394149589,0.9647635652408765,3.84
5000,5.000000000000004,-51761590.38009255,-51758451.59298323,325.15875763559785,0.9647635652408765,3.87
5500,5.500000000000171,-51761666.11719989,-51758609.040390864,316.69408042014265,0.9647635652408765,3.88
6000,6.000000000000338,-51761717.410261005,-51758757.05750706,306.674071898325,0.9647635652408765,3.88
6500,6.500000000000505,-51761806.5618603,-51758819.406374946,309.45066761600856,0.9647635652408765,3.89
7000,7.000000000000672,-51761681.78392433,-51758709.4174555,307.9186177909479,0.9647635652408765,3.88
7500,7.500000000000839,-51761769.14149964,-51758690.69200014,318.90815778133816,0.9647635652408765,3.88
8000,8.000000000001005,-51761761.9754192,-51758795.246857755,307.334565768534,0.9647635652408765,3.87
8500,8.500000000000728,-51761901.913045615,-51758883.636567205,312.67460829319657,0.9647635652408765,3.86
9000,9.000000000000451,-51761804.31964331,-51758896.54007997,301.22781742293137,0.9647635652408765,3.85
9500,9.500000000000174,-51761919.1522014,-51759037.87283451,298.4825624389563,0.9647635652408765,3.83
10000,9.999999999999897,-51761998.82163905,-51758958.40976757,314.9676968846035,0.9647635652408765,3.81
10500,10.49999999999962,-51761941.45668937,-51759198.40106706,284.16344505440196,0.9647635652408765,3.79
11000,10.999999999999343,-51761971.15636204,-51759049.90990825,302.6229032366831,0.9647635652408765,3.76
11500,11.499999999999066,-51762070.391052164,-51759186.4535425,298.7579294429007,0.9647635652408765,3.73
12000,11.999999999998789,-51762032.230797455,-51759115.55327353,302.14959061383684,0.9647635652408765,3.69
12500,12.499999999998511,-51761948.40618493,-51758908.67764184,314.89690832682476,0.9647635652408765,3.68
13000,12.999999999998234,-51762031.830178,-51758982.41484222,315.90039960228256,0.9647635652408765,3.66
13500,13.499999999997957,-51762261.039589554,-51759312.97191428,305.4014144106109,0.9647635652408765,3.66
14000,13.99999999999768,-51762198.87847388,-51759312.88252699,298.97117069423706,0.9647635652408765,3.66
14500,14.499999999997403,-51762082.7063445,-51759253.267198145,293.11223908914496,0.9647635652408765,3.65
15000,14.999999999997126,-51762182.60581253,-51759269.95476899,301.7324723714309,0.9647635652408765,3.65
15500,15.499999999996849,-51762250.7036077,-51759388.79308473,296.476071075977,0.9647635652408765,3.64

I am running this on an NVIDIA GeForce RTX 3080 Ti with CUDA Version: 12.2. Nothing else is running on this machine, CPU utilization seems to be at 100%, and nvidia-smi doesn't show anything suspicious.

waterbox_simulation.zip

peastman commented 1 year ago

Your script throws an exception. In the line

barostate = MonteCarloBarostat(1, unit.atmosphere, temperature)

I assume it's supposed to be 1*unit.atmosphere? Please provide the exact script you're running so I know we're really doing the same thing!

wiederm commented 1 year ago

Ah, sorry, that's it:

from sys import stdout

from openmm import LangevinIntegrator, MonteCarloBarostat, Platform
from openmm.app import Simulation, StateDataReporter
from openmmml import MLPotential
from openmmtools.testsystems import WaterBox
from openmm import unit  # same units module; the simtk namespace is deprecated

# define units
distance_unit = unit.angstrom
time_unit = unit.femto * unit.seconds
speed_unit = distance_unit / time_unit

# define simulation parameters
stepsize = 1 * time_unit
collision_rate = unit.pico * unit.second
temperature = 300 * unit.kelvin
pressure = 1 * unit.atmosphere

# setup waterbox
waterbox = WaterBox(20.0 * distance_unit)

# setup simulation
integrator = LangevinIntegrator(temperature, 1 / collision_rate, stepsize)
platform = Platform.getPlatformByName("CUDA")

potential = MLPotential("ani2x")
system = potential.createSystem(
    waterbox.topology,
    implementation="torchani",
    removeConstraints=True,
    constraints=None,
    rigidWater=False,
)

barostate = MonteCarloBarostat(pressure, temperature)
system.addForce(barostate)

sim = Simulation(
    waterbox.topology,
    system,
    integrator,
    platform=platform,
    platformProperties={
        "Precision": "mixed",
        "DeviceIndex": str(0),
    },
)

reporter = StateDataReporter(
    stdout,
    reportInterval=500,
    step=True,
    time=True,
    potentialEnergy=True,
    totalEnergy=True,
    temperature=True,
    density=True,
    speed=True,
)
sim.reporters.append(reporter)

# run simulation
sim.context.setPositions(waterbox.positions)
sim.step(5_000_000)

peastman commented 1 year ago

Here is what I get on an RTX 4080.

#"Step","Time (ps)","Potential Energy (kJ/mole)","Total Energy (kJ/mole)","Temperature (K)","Density (g/mL)","Speed (ns/day)"
500,0.5000000000000003,-51691082.50586787,-51614946.36875268,7887.228694831465,1.0416676947529535,0
1000,1.0000000000000007,-51639532.28212098,-51477304.840114154,16805.75064833584,1.0797802935271492,4
1500,1.4999999999999456,-51594129.21832637,-51333976.21989117,26950.22721208272,1.1064151053966407,4.11
2000,1.9999999999998905,-51551986.56589308,-51137845.278570846,42902.45301177477,1.1597543974695874,4.02
2500,2.4999999999998357,-51595905.6150527,-51312451.371840596,29364.090764834815,1.090852048078292,3.19
[W manager.cpp:331] Warning: FALLBACK path has been taken inside: runCudaFusionGroup. This is an indication that codegen Failed for some reason.
To debug try disable codegen fallback path via setting the env variable `export PYTORCH_NVFUSER_DISABLE=fallback`
 (function runCudaFusionGroup)
3000,2.9999999999997806,-51620080.505107574,-51494564.616022006,13002.662855816243,1.023385364269465,3.03
3500,3.4999999999997256,-51626752.912215255,-51572319.84621575,5638.925960341917,0.96949244738708,3.19
4000,3.9999999999996705,-51627244.09169403,-51604264.13947513,2380.5796486960476,0.9011221550655646,3.36
4500,4.4999999999998375,-51627791.69842257,-51617069.9321141,1110.7080828061412,0.8692633995466944,3.49
5000,5.000000000000004,-51627973.3993749,-51622049.39384985,613.6900049821345,0.8099132806565384,3.61
5500,5.500000000000171,-51628092.002763584,-51623949.9003647,429.095960673766,0.7598003194183287,3.7
6000,6.000000000000338,-51628160.3083801,-51624603.63410665,368.4492601098593,0.7528548301615088,3.79
6500,6.500000000000505,-51628103.29021663,-51624267.07678878,397.4077721054026,0.7376633356855503,3.85
7000,7.000000000000672,-51628147.158046596,-51624935.62667114,332.69460966375715,0.7107289420240409,3.92
7500,7.500000000000839,-51628171.52572475,-51625180.4110375,309.8608162257035,0.689823157526972,3.97
8000,8.000000000001005,-51628170.123556666,-51625364.50816604,290.64424666207907,0.6710355955865158,4.03
8500,8.500000000000728,-51628158.36537576,-51625362.77496039,289.60572178352726,0.6575651030436473,4.07
9000,9.000000000000451,-51628175.632074125,-51625243.9250579,303.7065522274509,0.6355202312152621,4.12
9500,9.500000000000174,-51628156.39232496,-51625233.54802592,302.7884300331938,0.6215680603588196,4.15
10000,9.999999999999897,-51628141.278956145,-51625177.47368027,307.03173162524956,0.6046996660030585,4.19
10500,10.49999999999962,-51628159.96785357,-51625282.258669235,298.1127138054723,0.5903070454303124,4.22
11000,10.999999999999343,-51628416.043807104,-51625427.551242776,309.5891806540922,0.5736186242619516,4.25
11500,11.499999999999066,-51628432.60942143,-51625444.759497255,309.52260714501006,0.5600480870208163,4.28
12000,11.999999999998789,-51628436.71577081,-51625621.17890174,291.67204991597606,0.54515584367781,4.3
12500,12.499999999998511,-51628419.52919633,-51625525.122295246,299.8424930620443,0.5313981087867292,4.32
13000,12.999999999998234,-51628259.4616944,-51625172.79263883,319.75965249853135,0.5250375844533886,4.33
13500,13.499999999997957,-51628261.17434256,-51625399.31640414,296.47062364447595,0.5098066026853267,4.35
14000,13.99999999999768,-51628263.84847739,-51625464.884610236,289.9551903380995,0.5000310395252776,4.37
14500,14.499999999997403,-51628257.959371455,-51625435.071681194,292.4335491205754,0.48796753584440716,4.39
15000,14.999999999997126,-51628260.012546144,-51625420.43842102,294.1621596441675,0.47440999093671243,4.4
15500,15.499999999996849,-51628288.897208616,-51625333.946627244,306.11444053116963,0.4667819504638678,4.42

It's the same as before: there's an initial slowdown, but then it gradually speeds up again. I'm not sure what the warning means, or whether it could be related.

RaulPPelaez commented 1 year ago

@peastman, this warning is common with PyTorch JIT script models. During the first few runs through the model, PyTorch measures performance and looks for optimization opportunities, so the first few iterations are expected to be slower. One of the optimizations is transpiling with nvfuser, which tries to fuse compatible kernels into a single one. https://pytorch.org/blog/introducing-nvfuser-a-deep-learning-compiler-for-pytorch/ This sounds really cool, but in practice it fails a lot, at least with the now oldish version of PyTorch in conda-forge.

Just for completeness, this is a plot of the performance you posted, @peastman, which is consistent with what I see on a 4090.

image

wiederm commented 1 year ago

Thank you for looking into this.

I have rebuilt my conda environment and applied the patch described here: https://github.com/openmm/openmm-ml/issues/50#issuecomment-1495879938. Now I see similar performance and performance behavior as reported by @peastman and @RaulPPelaez for a water box with a 20 A edge length: no noticeable performance degradation over the simulated time (~50 ps).

However, increasing the box size to 30 A, as in the initially provided simulation script, results in performance degradation as the simulation progresses.

image

sef43 commented 1 year ago

Is this running with NVFuser enabled or disabled?

Can you try with NVFuser disabled?

torch._C._jit_set_nvfuser_enabled(False)

NVFuser is no longer the default in newer PyTorch: https://discuss.pytorch.org/t/nvfuser-with-torch-compile/187829/2
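Because `torch._C._jit_set_nvfuser_enabled` is a private function that may be missing or behave differently depending on the PyTorch version, a guarded call is safer in a shared script. The following is a minimal sketch, not an official recipe: `disable_nvfuser` is a hypothetical helper name, and I am assuming the call should be made before the ML potential is created and first evaluated.

```python
def disable_nvfuser():
    """Try to switch off NVFuser for TorchScript models.

    Returns True if the private setter was found and called without error,
    False if PyTorch is not installed, the setter does not exist in this
    version, or the call raised.
    """
    try:
        import torch
    except ImportError:
        return False
    # Private API: look it up defensively instead of assuming it exists.
    setter = getattr(torch._C, "_jit_set_nvfuser_enabled", None)
    if setter is None:
        return False
    try:
        setter(False)
    except Exception:
        return False
    return True


nvfuser_disabled = disable_nvfuser()
print("NVFuser disabled:", nvfuser_disabled)
```

In the script above this would go near the top, ahead of the `MLPotential("ani2x")` line.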

wiederm commented 1 year ago

Disabling NVFuser solved the issue! That's great, thank you!