torchmd / torchmd-net

Training neural network potentials
MIT License

Pre-trained model #19

Closed: raimis closed this issue 2 years ago

raimis commented 3 years ago

We are writing a paper about NNP/MM in ACEMD. So far, we have used ANI-2x for protein-ligand simulations, but to demonstrate general utility, it would be good to include one more NNP.

Would it be possible to have a pre-trained TorchMD-NET model?

giadefa commented 3 years ago

On some MD17 molecules, for instance.


PhilippThoelke commented 3 years ago

So you need a checkpoint file for a graph network trained, e.g., on aspirin from the MD17 dataset? Would it work for you if the model were trained with the next version of the code, i.e. once #20 is merged? This change adds some features that improve the performance on MD17.

giadefa commented 3 years ago

yes


PhilippThoelke commented 3 years ago

I have added a graph network checkpoint pretrained on aspirin from the MD17 dataset. It used 950 samples for training, 50 for validation and the remaining samples for testing, which is the standard benchmark for this dataset. I trained on both energies and forces; the exact hyperparameters can be found in the hparams.yaml file. You can find the model checkpoint, hyperparameters and splits here.
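
The split described above (950 training samples, 50 validation samples, everything else for testing) can be sketched as a random permutation of sample indices cut into three pieces. This is a minimal NumPy sketch, not TorchMD-Net's own splitting code; `random_split_indices` and the seed are hypothetical, and the dataset size is passed in rather than assumed.

```python
import numpy as np


def random_split_indices(n_samples, n_train=950, n_val=50, seed=1):
    """Hypothetical helper: shuffle indices and cut them into train/val/test.

    Every sample lands in exactly one split; the test split takes whatever
    remains after the training and validation samples are drawn.
    """
    if n_train + n_val > n_samples:
        raise ValueError("not enough samples for the requested train/val sizes")
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_samples)
    idx_train = perm[:n_train]
    idx_val = perm[n_train:n_train + n_val]
    idx_test = perm[n_train + n_val:]
    return idx_train, idx_val, idx_test


# Usage with an arbitrary dataset size (not the real MD17 aspirin count).
train, val, test = random_split_indices(5000)
print(len(train), len(val), len(test))  # 950 50 4000
```

Storing the three index arrays next to the checkpoint, as with the splits shared here, is what lets someone else reproduce the test-set numbers.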

I also included the metrics.csv, which contains losses and learning rate for each epoch during training. The model checkpoint comes from epoch 1269 and reached an MAE of 0.224 for the energy and 0.630 for the forces on the test set.

The model was trained on version 0.1.0.
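
Since metrics.csv records losses for each epoch, the epoch with the lowest validation loss can be found with a few lines of standard-library Python. The column names `epoch` and `val_loss` are assumptions (check the header of the actual file), and rows with an empty loss are skipped because training and validation metrics may be logged on separate rows. `best_epoch` is a hypothetical helper and the example rows are made up.

```python
import csv
import io


def best_epoch(csv_text, loss_column="val_loss", epoch_column="epoch"):
    """Return (epoch, loss) for the row with the smallest non-empty loss.

    Column names are assumptions; adjust them to the header of your metrics.csv.
    """
    best = None
    for row in csv.DictReader(io.StringIO(csv_text)):
        value = row.get(loss_column, "")
        if not value:  # skip rows where this metric was not logged
            continue
        loss = float(value)
        if best is None or loss < best[1]:
            best = (int(float(row[epoch_column])), loss)
    return best


# Made-up rows purely for illustration, not the real training log.
example = """epoch,train_loss,val_loss
0,2.5,
0,,2.1
1,1.9,
1,,1.7
2,1.6,
2,,1.8
"""
print(best_epoch(example))  # (1, 1.7)
```

For the real file, pass `open("metrics.csv").read()` instead of the example string.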

raimis commented 3 years ago

The pre-trained model cannot be loaded:

import torch
from torchmdnet.models import load_model

model = load_model('examples/pretrained/md17-aspirin-graph-network/epoch=1269-val_loss=0.8859-test_loss=0.5893.ckpt')

Traceback (most recent call last):
  File "tmn_load.py", line 10, in <module>
    model = load_model('examples/pretrained/md17-aspirin-graph-network/epoch=1269-val_loss=0.8859-test_loss=0.5893.ckpt')
  File "/shared/raimis/opt/miniconda/envs/tmn/lib/python3.8/site-packages/torchmdnet/models/model.py", line 78, in load_model
    model = create_model(args)
  File "/shared/raimis/opt/miniconda/envs/tmn/lib/python3.8/site-packages/torchmdnet/models/model.py", line 27, in create_model
    max_num_neighbors=args['max_num_neighbors'],
KeyError: 'max_num_neighbors'
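
The `KeyError` says that the hyperparameters stored in the checkpoint lack a key (`max_num_neighbors`) that the installed `create_model` reads, which is consistent with the checkpoint having been produced by version 0.1.0 as noted above. A quick way to see which keys a checkpoint carries is sketched below. It assumes a PyTorch Lightning checkpoint whose hyperparameters live under the `hyper_parameters` key; `missing_hparams`, `load_hparams` and `REQUIRED_KEYS` are hypothetical names, and only `max_num_neighbors` is known from the traceback to be required.

```python
# Keys the installed TorchMD-Net version expects; only max_num_neighbors is
# known from the traceback above, so extend this set as needed.
REQUIRED_KEYS = {"max_num_neighbors"}


def missing_hparams(hparams, required=REQUIRED_KEYS):
    """Return the sorted list of required keys absent from a hyperparameter dict."""
    return sorted(key for key in required if key not in hparams)


def load_hparams(path):
    """Read the hyperparameter dict from a Lightning checkpoint (assumed layout)."""
    import torch  # imported lazily so missing_hparams works without PyTorch

    checkpoint = torch.load(path, map_location="cpu")
    return checkpoint["hyper_parameters"]


# Made-up stand-in dict; with a real file use load_hparams(path) instead.
old_hparams = {"cutoff": 5.0}
print(missing_hparams(old_hparams))  # ['max_num_neighbors']
```

This only diagnoses the mismatch; loading the old checkpoint still needs the code version it was trained with, or a checkpoint retrained with the current one.
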
PhilippThoelke commented 3 years ago

It should be possible to load it with version 0.1.x, under which it was trained. I haven't updated the model since then, but I can do that now. I'll have to retrain it with the current version, which will take roughly half a day.

raimis commented 3 years ago

Thanks! It would be most useful if I could test the simulations with the latest version.

PhilippThoelke commented 3 years ago

Never mind, I still had a recent model checkpoint trained with the latest version; I just pushed it.

raimis commented 3 years ago

Thanks! Now it works.

raimis commented 3 years ago

I tried to run MD simulations with the NNP model:

In both cases, the simulations "explode" within 1 ps. I tried reducing the timestep to 0.5 fs, but it didn't help. The same simulations with ANI-2x run without problems.

raimis commented 3 years ago

@PhilippThoelke would it be possible to run the system with TorchMD to verify the problem?

PhilippThoelke commented 3 years ago

Yes that is possible. You can use torchmdnet.calculators.External as an external force inside TorchMD.

raimis commented 3 years ago

@stefdoerr could you help to set up the simulations?

stefdoerr commented 3 years ago

Don't you already have the input files since you ran them?

raimis commented 3 years ago

@stefdoerr I do have input files (PDB and PRMTOP), but I haven't used TorchMD.

stefdoerr commented 3 years ago

See https://github.com/torchmd/torchmd/blob/master/examples/tutorial.ipynb. It's relatively simple, but if you don't want to try it, send me the input files and I can take a look.

raimis commented 3 years ago

Where do I need to add torchmdnet.calculators.External?

PhilippThoelke commented 3 years ago

You have to pass that as the --external argument of run.py (https://github.com/torchmd/torchmd/blob/3e12a6858c603af6c2b76696ff4edf0956dc0ea5/torchmd/run.py#L50). The argument additionally requires the path to the model checkpoint and the embedding indices. I'm not sure whether it's possible to pass all of that via the command line or whether you have to use a YAML config file for that. I believe you would have to enter it as

external:
  module: torchmdnet.calculators.External
  embeddings: [ 1, 1, 6, 6, ...]
  file: path/to/checkpoint
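
The `embeddings` entry lists one integer per atom, in the same order as the atoms in the coordinate file; judging by the aspirin config in the next comment (8 for O, 6 for C, 1 for H), these are atomic numbers. A minimal sketch for building that list from element symbols and rendering the `external:` section follows. `embeddings_from_elements` and `external_yaml` are hypothetical helpers, not part of TorchMD or TorchMD-Net, and the default `module` value is copied from the config raimis actually ran rather than from the snippet above.

```python
# Atomic numbers for the elements that occur in this thread; extend as needed.
ATOMIC_NUMBERS = {"H": 1, "C": 6, "N": 7, "O": 8}


def embeddings_from_elements(elements):
    """Map element symbols, in coordinate-file order, to atomic numbers."""
    return [ATOMIC_NUMBERS[symbol.strip().capitalize()] for symbol in elements]


def external_yaml(embeddings, checkpoint, module="torchmdnet.calculators"):
    """Render a TorchMD `external:` section as YAML text (hypothetical helper)."""
    lines = ["external:", f"  module: {module}", "  embeddings:"]
    lines += [f"  - {z}" for z in embeddings]
    lines.append(f"  file: {checkpoint}")
    return "\n".join(lines)


# Aspirin (C9H8O4) in the atom order of the config that follows: 4 O, 9 C, 8 H.
elements = ["O"] * 4 + ["C"] * 9 + ["H"] * 8
z = embeddings_from_elements(elements)
print(len(z), z[:5])  # 21 [8, 8, 8, 8, 6]
print(external_yaml(z, "path/to/checkpoint"))
```

Reading the element symbols from the actual PDB, rather than typing them, avoids silently misaligning embeddings and coordinates.
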
raimis commented 3 years ago

I tried a simulation of aspirin with TorchMD:

coordinates: aspirin.pdb
cutoff: null
device: cuda
extended_system: null
external:
  embeddings:
  - 8
  - 8
  - 8
  - 8
  - 6
  - 6
  - 6
  - 6
  - 6
  - 6
  - 6
  - 6
  - 6
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  file: ../torchmd-net.git/examples/pretrained/md17-aspirin-graph-network/epoch=1359-val_loss=0.5227-test_loss=0.4333.ckpt
  module: torchmdnet.calculators
forcefield: aspirin.prmtop
forceterms: null
langevin_gamma: 0.1
langevin_temperature: 300
log_dir: ./
minimize: 100
output: output
output_period: 1
precision: single
replicas: 1
rfa: false
save_period: 1
seed: 1
steps: 100
structure: null
switch_dist: null
temperature: 300
timestep: 1
topology: aspirin.prmtop

All the input files: aspirin_torchmd.zip

The simulation is unstable: the temperature rises uncontrollably from the first steps.

iter,ns,epot,ekin,etot,T,t
1,1e-06,-406880.0887773633,15.103168487548828,-406864.98560887575,241.27809871894112,4.027363538742065
2,2e-06,-406864.92625647783,14.164608001708984,-406850.7616484761,226.28428535170937,4.054379224777222
3,3e-06,-406861.4632962942,14.159696578979492,-406847.30359971523,226.20582375345904,4.08146595954895
4,4e-06,-406852.1209000349,16.686145782470703,-406835.43475425243,266.566683187086,4.108911752700806
5,4.9999999999999996e-06,-406845.173047781,35.16876983642578,-406810.00427794456,561.8326993711507,4.13638710975647
6,6e-06,-406819.02439904213,19.055503845214844,-406799.9688951969,304.417959827123,4.16302752494812
7,7e-06,-406835.88994038105,36.33468246459961,-406799.55525791645,580.4585382095441,4.190194845199585
8,8e-06,-406816.4535654783,69.41360473632812,-406747.039960742,1108.905233349976,4.2174201011657715
9,9e-06,-406790.12160658836,100.08391571044922,-406690.0376908779,1598.87356847484,4.24408483505249
10,9.999999999999999e-06,-406718.0202502012,50.58804702758789,-406667.43220317364,808.1607389061002,4.270519733428955
11,1.1e-05,-406711.58965563774,57.65319061279297,-406653.93646502495,921.0287382811944,4.2969443798065186
12,1.2e-05,-406745.146247983,97.73880767822266,-406647.40744030476,1561.4096940717563,4.3239805698394775
13,1.3e-05,-406780.00586628914,137.84088134765625,-406642.1649849415,2202.053549539873,4.350756406784058
14,1.4e-05,-406737.1390030384,140.08062744140625,-406597.058375597,2237.8342322197154,4.377088785171509
15,1.4999999999999999e-05,-406667.7502512932,89.35147857666016,-406578.3987727165,1427.4193449192999,4.403895616531372
16,1.6e-05,-406643.7843030691,93.51447296142578,-406550.2698301077,1493.924553476163,4.430569410324097
17,1.7e-05,-406690.91243588924,150.41690063476562,-406540.4955352545,2402.959606143029,4.457834005355835
18,1.8e-05,-406717.71087527275,181.7862091064453,-406535.9246661663,2904.094656871925,4.500119924545288
19,1.8999999999999998e-05,-406734.17486667633,194.3318328857422,-406539.8430337906,3104.5151352111134,4.535091400146484
20,1.9999999999999998e-05,-406659.9773603678,175.61720275878906,-406484.360157609,2805.5427454783216,4.569720983505249
21,2.1e-05,-406585.90433883667,176.90951538085938,-406408.9948234558,2826.1878659151803,4.612438201904297
22,2.2e-05,-406574.9152857065,188.20639038085938,-406386.70889532566,3006.659170576359,4.6477577686309814
23,2.3e-05,-406620.7030599117,219.14315795898438,-406401.55990195274,3500.884237847075,4.679394721984863
24,2.4e-05,-406645.82324945927,184.114501953125,-406461.70874750614,2941.289903139019,4.707014799118042
25,2.4999999999999998e-05,-406701.47101426125,153.501708984375,-406547.96930527687,2452.2404371236057,4.734936475753784
26,2.6e-05,-406696.4496754408,133.06582641601562,-406563.3838490248,2125.770471844318,4.761986255645752
27,2.7e-05,-406683.0426847935,122.48103332519531,-406560.5616514683,1956.6749105790145,4.788716793060303
28,2.8e-05,-406666.3019833565,107.2762451171875,-406559.0257382393,1713.77340330412,4.815651178359985
29,2.9e-05,-406697.5885055065,154.06419372558594,-406543.52431178093,2461.2263164130854,4.843580484390259
30,2.9999999999999997e-05,-406699.77242171764,160.65704345703125,-406539.1153782606,2566.549265677288,4.872596979141235
31,3.1e-05,-406650.5109888315,117.29191589355469,-406533.21907293797,1873.7770478578418,4.89995551109314
32,3.2e-05,-406690.5119087696,119.58175659179688,-406570.9301521778,1910.3580083693134,4.927629232406616
33,3.2999999999999996e-05,-406737.56370294094,149.2154998779297,-406588.348203063,2383.7668327426763,4.95459771156311
34,3.4e-05,-406735.5253405571,144.68548583984375,-406590.83985471725,2311.3983641540776,4.980782508850098
35,3.5e-05,-406708.95912611485,159.289794921875,-406549.66933119297,2544.707019309533,5.007444620132446
36,3.6e-05,-406693.1858738661,140.61151123046875,-406552.5743626356,2246.315275874317,5.037170886993408
37,3.7e-05,-406639.2426587343,125.33792877197266,-406513.90472996235,2002.3147577544935,5.06561803817749
38,3.7999999999999995e-05,-406632.5847103596,132.43565368652344,-406500.14905667305,2115.7032546135924,5.09315037727356
39,3.9e-05,-406643.4117741585,185.24386596679688,-406458.1679081917,2959.3318658043354,5.120187997817993
40,3.9999999999999996e-05,-406671.7958230972,248.03517150878906,-406423.76065158844,3962.443685005843,5.147157669067383
41,4.1e-05,-406553.23693335056,179.92616271972656,-406373.31077063084,2874.3798022646606,5.175485134124756
42,4.2e-05,-406463.92002630234,114.83243560791016,-406349.0875906944,1834.486038978066,5.201632022857666
43,4.2999999999999995e-05,-406518.50567913055,158.01731872558594,-406360.48836040497,2524.378792318035,5.228895425796509
44,4.4e-05,-406609.36943364143,244.32044982910156,-406365.04898381233,3903.0997807857493,5.256029367446899
45,4.4999999999999996e-05,-406683.1779911518,327.4761047363281,-406355.7018864155,5231.538798749737,5.285285234451294
46,4.6e-05,-406645.70803165436,328.3685607910156,-406317.33947086334,5245.796078620695,5.312484264373779
47,4.7e-05,-406546.0222530365,272.8658447265625,-406273.15640830994,4359.121880633125,5.339726686477661
48,4.8e-05,-406524.3602576256,271.5589294433594,-406252.8013281822,4338.243477867644,5.367532968521118
49,4.9e-05,-406589.9699988365,378.48468017578125,-406211.48531866074,6046.417617756432,5.394550561904907
50,4.9999999999999996e-05,-406630.7346057892,405.6400146484375,-406225.09459114075,6480.233043773889,5.421330451965332
51,5.1e-05,-406475.4479403496,313.1947326660156,-406162.25320768356,5003.389168884751,5.448347568511963
52,5.2e-05,-406316.43815875053,332.8667907714844,-405983.57136797905,5317.65678640415,5.476334810256958
53,5.3e-05,-406198.64964079857,332.2574768066406,-405866.3921639919,5307.92279481943,5.5077526569366455
54,5.4e-05,-406406.16769218445,666.53564453125,-405739.6320476532,10648.126793625168,5.534675359725952
55,5.4999999999999995e-05,-406424.89574217796,873.1506958007812,-405551.7450463772,13948.870394421729,5.561588525772095
56,5.6e-05,-406168.2948439121,687.5379028320312,-405480.7569410801,10983.644798062065,5.588784694671631
57,5.6999999999999996e-05,-405880.53273034096,471.45465087890625,-405409.07807946205,7531.643568039544,5.615290641784668
58,5.8e-05,-405950.1033626795,561.5723876953125,-405388.53097498417,8971.304141106837,5.642324209213257
59,5.9e-05,-406326.1232010126,890.3369140625,-405435.7862869501,14223.425900425305,5.66890287399292
60,5.9999999999999995e-05,-406499.3240991831,944.90673828125,-405554.41736090183,15095.196843441758,5.69551157951355
61,6.1e-05,-406246.7270579338,662.7271728515625,-405583.99988508224,10587.285202229465,5.7228147983551025
62,6.2e-05,-405939.76182341576,383.782470703125,-405555.97935271263,6131.051569029803,5.749572992324829
63,6.3e-05,-405940.30504751205,388.75714111328125,-405551.5479063988,6210.52356984244,5.77661657333374
64,6.4e-05,-406131.46995961666,551.8641357421875,-405579.6058238745,8816.211613663636,5.803932428359985
65,6.5e-05,-406212.5055809021,601.13427734375,-405611.37130355835,9603.318378647316,5.8413918018341064
66,6.599999999999999e-05,-406231.4842660427,620.8232421875,-405610.6610238552,9917.856086887838,5.8763813972473145
67,6.7e-05,-406165.6405694485,587.7647705078125,-405577.87579894066,9389.736096701363,5.913886547088623
68,6.8e-05,-406179.55393338203,581.9989013671875,-405597.55503201485,9297.624435174239,5.952528476715088
69,6.9e-05,-406209.120762825,632.543701171875,-405576.57706165314,10105.094285428406,5.990658521652222
70,7e-05,-406223.5871543884,640.5198974609375,-405583.0672569275,10232.516652279484,6.025377988815308
71,7.099999999999999e-05,-406243.78583192825,633.56640625,-405610.21942567825,10121.432336414438,6.0626561641693115
72,7.2e-05,-406195.6852698326,616.900390625,-405578.7848792076,9855.187239133347,6.09675407409668
73,7.3e-05,-406080.0879020691,685.86767578125,-405394.22022628784,10956.962369962712,6.13096809387207
74,7.4e-05,-405992.6495513916,722.2652587890625,-405270.38429260254,11538.425765099284,6.16308069229126
75,7.5e-05,-405902.5705215931,666.3377075195312,-405236.23281407356,10644.964684568848,6.1934545040130615
76,7.599999999999999e-05,-405925.06811475754,673.291015625,-405251.77709913254,10756.046075263628,6.223395109176636
77,7.7e-05,-406178.0604672432,936.1538696289062,-405241.9065976143,14955.366879383875,6.253476858139038
78,7.8e-05,-406207.9582891464,1001.4998779296875,-405206.45841121674,15999.290917884982,6.283278226852417
79,7.9e-05,-406096.5415635109,953.51416015625,-405143.02740335464,15232.703247252584,6.313279867172241
80,7.999999999999999e-05,-406192.39066147804,1030.43701171875,-405161.9536497593,16461.57117574985,6.343334436416626
81,8.099999999999999e-05,-406306.50572681427,1111.22900390625,-405195.276722908,17752.249902057247,6.3734519481658936
82,8.2e-05,-406053.14411354065,925.3079833984375,-405127.8361301422,14782.10026908506,6.403440713882446
83,8.3e-05,-405683.14057159424,582.029052734375,-405101.11151885986,9298.106113211188,6.4332404136657715
84,8.4e-05,-405481.014734745,430.88470458984375,-405050.1300301552,6883.525293134223,6.4636194705963135
85,8.499999999999999e-05,-405629.4579527378,558.5642700195312,-405070.8936827183,8923.248468938862,6.494351148605347
86,8.599999999999999e-05,-405920.31214261055,825.638671875,-405094.67347073555,13189.850139264538,6.524912118911743
87,8.7e-05,-406069.9568309784,1007.77734375,-405062.1794872284,16099.575505132141,6.554839134216309
88,8.8e-05,-406112.59860801697,1110.46728515625,-405002.1313228607,17740.081193755406,6.585700750350952
89,8.9e-05,-406089.90273809433,1081.825927734375,-405008.07681035995,17282.52606092505,6.615431308746338
90,8.999999999999999e-05,-405868.9691205025,820.49609375,-405048.4730267525,13107.69575731902,6.645550966262817
91,9.099999999999999e-05,-405675.4124674797,607.3097534179688,-405068.10271406174,9701.973646056946,6.675546169281006
92,9.2e-05,-405572.9032449722,525.2464599609375,-405047.6567850113,8390.985462600582,6.705816984176636
93,9.3e-05,-405769.4916372299,714.1212158203125,-405055.3704214096,11408.32199217208,6.737977981567383
94,9.4e-05,-406103.68017959595,1042.67578125,-405061.00439834595,16657.0895562535,6.768180847167969
95,9.499999999999999e-05,-406297.6377902031,1278.121826171875,-405019.5159640312,20418.41779121806,6.7977800369262695
96,9.6e-05,-406258.2769627571,1284.951904296875,-404973.32505846024,20527.530542324588,6.827901363372803
97,9.7e-05,-405986.5238389969,1074.440673828125,-404912.08316516876,17164.54419357545,6.857760190963745
98,9.8e-05,-405746.92120170593,938.0362548828125,-404808.8849468231,14985.438604763453,6.887917518615723
99,9.9e-05,-405668.9798822403,967.4277954101562,-404701.55208683014,15454.978160168965,6.917438983917236
100,9.999999999999999e-05,-405585.2989048958,924.967529296875,-404660.3313755989,14776.661402505919,6.9477856159210205
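
The `T` column in this log is consistent with computing temperature from the kinetic energy as T = 2·E_kin / (3·N·k_B), with N = 21 atoms and k_B ≈ 0.001987191 kcal/(mol·K), i.e. with energies in kcal/mol. That reading is inferred from the numbers, not taken from TorchMD's source. The sketch below checks the relation, and etot = epot + ekin, on three rows copied from the log, which makes the temperature blow-up easy to quantify.

```python
import csv
import io

K_B = 0.001987191  # kcal/(mol*K); inferred value that reproduces the log's T column
N_ATOMS = 21       # aspirin, C9H8O4


def temperature(ekin, n_atoms=N_ATOMS):
    """Temperature from kinetic energy using 3N degrees of freedom."""
    return 2.0 * ekin / (3.0 * n_atoms * K_B)


# Rows 1, 50 and 100 copied from the log above.
LOG = """iter,ns,epot,ekin,etot,T,t
1,1e-06,-406880.0887773633,15.103168487548828,-406864.98560887575,241.27809871894112,4.027363538742065
50,4.9999999999999996e-05,-406630.7346057892,405.6400146484375,-406225.09459114075,6480.233043773889,5.421330451965332
100,9.999999999999999e-05,-405585.2989048958,924.967529296875,-404660.3313755989,14776.661402505919,6.9477856159210205
"""

rows = list(csv.DictReader(io.StringIO(LOG)))
for row in rows:
    ekin, epot, etot = float(row["ekin"]), float(row["epot"]), float(row["etot"])
    # T recomputed from ekin should match the logged T column.
    assert abs(temperature(ekin) - float(row["T"])) < 0.01
    # etot is the sum of potential and kinetic energy.
    assert abs(epot + ekin - etot) < 1e-6

# Hottest of the three sampled rows (not necessarily of the whole log).
peak = max(rows, key=lambda r: float(r["T"]))
print(peak["iter"], round(float(peak["T"])))  # 100 14777
```

With a 300 K target, temperatures in the thousands within 0.1 ps are far outside anything a thermostat would produce for a stable system.
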
raimis commented 3 years ago

I observe the same with ACEMD, so the problem is probably the pre-trained model or a bug in TorchMD-Net.

raimis commented 3 years ago

I visualised the trajectory: the molecule literally explodes.

raimis commented 3 years ago

@PhilippThoelke would it be possible to train a less "explosive" model?

PhilippThoelke commented 3 years ago

I tried some small simulations myself, and while visualizing them I found that the simulation step before the "explosion" usually has two hydrogens that are very close to each other. In one of the runs I checked the distance, which turned out to be 0.11 Å. I then compared this to the minimum distance between any two atoms in the MD17 aspirin dataset, which is 0.89 Å. The dataset I trained the model on might just not be very good for simulation. I can start training a model on the ANI dataset, which would also make it easier to compare with the ANI model. This will, however, take some time, as the ANI dataset is much larger.
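
The diagnosis above compares the smallest interatomic distance in a simulation frame (0.11 Å) with the smallest distance seen anywhere in the training data (0.89 Å). A minimal NumPy sketch of that check follows; `min_pair_distance`, `outside_training_range` and the example coordinates are hypothetical, and in practice the frames would come from the trajectory and from the dataset.

```python
import numpy as np


def min_pair_distance(positions):
    """Smallest distance between two distinct atoms; positions has shape (n_atoms, 3)."""
    diff = positions[:, None, :] - positions[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    i, j = np.triu_indices(len(positions), k=1)  # each unordered pair once
    return dist[i, j].min()


def outside_training_range(frame, dataset_min):
    """True if the frame has a pair of atoms closer than anything seen in training."""
    return min_pair_distance(frame) < dataset_min


# Made-up three-atom frame (Angstrom) with two atoms 0.11 apart.
frame = np.array([[0.0, 0.0, 0.0], [0.11, 0.0, 0.0], [1.5, 0.0, 0.0]])
print(round(min_pair_distance(frame), 2))  # 0.11
print(outside_training_range(frame, dataset_min=0.89))  # True
```

A flagged frame means the model is extrapolating there, which is a plausible cause of the explosions, though it does not by itself prove it.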

raimis commented 3 years ago

@PhilippThoelke yes, I think that training on the ANI data is the easiest solution. Anyway, I don't need anything very accurate, just good enough that the simulation stays stable and looks physical.

PhilippThoelke commented 3 years ago

I just merged a couple of changes into main, including two new model checkpoints trained on the ANI1 dataset: one from a Transformer model and one from an equivariant Transformer. The equivariant model currently only works with TorchScript on the PyTorch Geometric main branch, as PyG had a bug that was only recently fixed; however, it has a lower loss than the Transformer checkpoint. For testing, I recommend using the Transformer checkpoint instead of the equivariant one, so you don't have to install PyTorch Geometric from GitHub.

I tested simulating with both models using TorchMD and both are capable of simulating aspirin without "explosions".

Since the ANI1 dataset only includes energies and not forces, the model checkpoint has the derivative flag set to False. In TorchMD this does not change anything, because the torchmdnet.calculators.External module overwrites the flag during loading. If you want to load the model yourself and enable force computation, you have to pass derivative=True to the load_model function.
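
With `derivative=True` the forces are obtained by differentiating the predicted energy with respect to the atomic positions (F = -∂E/∂r), which is why a model trained only on ANI1 energies can still drive MD. The sketch below illustrates that relation with a toy harmonic pair energy and central finite differences in NumPy. The toy energy stands in for the neural network and is purely illustrative; TorchMD-Net itself uses automatic differentiation, not finite differences.

```python
import numpy as np


def pair_energy(positions, k=1.0, r0=1.0):
    """Toy stand-in for a learned energy: harmonic springs between all atom pairs."""
    energy = 0.0
    n = len(positions)
    for i in range(n):
        for j in range(i + 1, n):
            r = np.linalg.norm(positions[i] - positions[j])
            energy += 0.5 * k * (r - r0) ** 2
    return energy


def numerical_forces(energy_fn, positions, h=1e-6):
    """Forces as the negative gradient of the energy, via central differences."""
    forces = np.zeros_like(positions)
    for idx in np.ndindex(positions.shape):
        shifted = positions.copy()
        shifted[idx] += h
        e_plus = energy_fn(shifted)
        shifted[idx] -= 2 * h
        e_minus = energy_fn(shifted)
        forces[idx] = -(e_plus - e_minus) / (2 * h)
    return forces


# Two atoms 1.5 apart along x: the spring (r0 = 1.0) is stretched, so they attract.
pos = np.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]])
f = numerical_forces(pair_energy, pos)
print(np.round(f, 4))
```

The forces are equal and opposite and point along the bond, as expected from the gradient of a pairwise energy.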

raimis commented 3 years ago

I have tried to run aspirin with ANI1-transformer:

PhilippThoelke commented 3 years ago

I just pushed the most recent checkpoints from training on ANI1, which at least have a lower loss than the ones you tested. The models are also still training and haven't converged yet. It might also make sense to try the equivariant Transformer, as it has a lower loss. There hasn't been a new torch-geometric release yet, so the TorchScript fix is still only on their main branch.

Do you have any ideas why it might explode? What is the difference between the ACEMD MD simulation and ACEMD NNP/MD simulation?

What do you mean by OpenMM-Torch/PyTorch-Geometric incompatibility? How did you write the interface?

raimis commented 3 years ago

Thanks @PhilippThoelke, I'll try with the new model.

> Do you have any ideas why it might explode? What is the difference between the ACEMD MD simulation and ACEMD NNP/MD simulation?

MD is just the molecule in vacuum. NNP/MM adds solvent at the MM level.

> What do you mean by OpenMM-Torch/PyTorch-Geometric incompatibility? How did you write the interface?

Current PyTorch Geometric packages are not compatible with conda-forge packages, so I had to rebuild PyG.

raimis commented 3 years ago

I have managed to run the latest checkpoint of ANI1-transformer with ACEMD on GPU:

  • The simulation of aspirin is stable after ~0.1 ns and keeps running
  • Speed ~10 ns/day on GTX 1080 Ti

Meanwhile, ANI1-equivariant_transformer fails with the following error:

The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: The size of tensor a (128) must match the size of tensor b (384) at non-singleton dimension 2

raimis commented 3 years ago

For some reason, that issue does not manifest outside of ACEMD.

giadefa commented 3 years ago

It seems that you are loading a different model: the weights and the model are different.

PhilippThoelke commented 3 years ago

At what point does the error occur: during loading or when running inference? Could you maybe share the code snippet where the error occurred?
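
The suggestion that the weights in the file do not match the architecture being built can be checked by comparing parameter shapes. The sketch below works on plain dictionaries mapping parameter names to shapes; with PyTorch these could be built as `{k: tuple(v.shape) for k, v in state_dict.items()}` from the checkpoint's state dict and from the freshly constructed model. `shape_mismatches` and the parameter names in the example are hypothetical.

```python
def shape_mismatches(checkpoint_shapes, model_shapes):
    """Compare two {parameter_name: shape} dicts.

    Returns (mismatched, only_in_checkpoint, only_in_model), where mismatched
    maps each shared name with differing shapes to (checkpoint_shape, model_shape).
    """
    shared = checkpoint_shapes.keys() & model_shapes.keys()
    mismatched = {
        name: (checkpoint_shapes[name], model_shapes[name])
        for name in sorted(shared)
        if checkpoint_shapes[name] != model_shapes[name]
    }
    only_ckpt = sorted(checkpoint_shapes.keys() - model_shapes.keys())
    only_model = sorted(model_shapes.keys() - checkpoint_shapes.keys())
    return mismatched, only_ckpt, only_model


# Made-up shapes echoing the 128 vs 384 sizes in the error message.
ckpt = {"layer.weight": (384, 128), "layer.bias": (384,), "extra.weight": (8, 8)}
model = {"layer.weight": (128, 128), "layer.bias": (128,)}
print(shape_mismatches(ckpt, model))
```

An empty result on all three outputs would rule out a plain weight/architecture mismatch and point back at how ACEMD invokes the model.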