mir-group / allegro

Allegro is an open-source code for building highly scalable and accurate equivariant deep learning interatomic potentials
https://www.nature.com/articles/s41467-023-36329-y
MIT License

Training dataset is wrongly thought to only have two datapoints #79

Open jonathan-booth opened 8 months ago

jonathan-booth commented 8 months ago

This is related to this discussion:

https://github.com/mir-group/allegro/discussions/74

As advised there, I used ase.io.read and ase.io.write to make my .extxyz training set out of ASE Atoms objects, but when I run nequip-train I get the following error:

Traceback (most recent call last):
  File ".../nequip-train", line 8, in <module>
    sys.exit(main())
  File ".../nequip/scripts/train.py", line 72, in main
    trainer = fresh_start(config)
  File ".../nequip/scripts/train.py", line 160, in fresh_start
    trainer.set_dataset(dataset, validation_dataset)
  File ".../nequip/train/trainer.py", line 1164, in set_dataset
    raise ValueError(
ValueError: too little data for training and validation. please reduce n_train and n_val
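
The error comes from a size check that runs before training starts. From the message, it appears to require that the split requested by n_train and n_val fits into the frames that were actually loaded. With total_n = 2, the settings n_train: 200 and n_val: 49 cannot pass. The sketch below uses only the standard library. It counts frames by walking the atom-count lines and then repeats that comparison. count_xyz_frames and check_split are hypothetical helpers, not NequIP functions, and the exact comparison inside trainer.py is assumed from the error text.

```python
def count_xyz_frames(text: str) -> int:
    """Count frames in (ext)xyz text.

    Each frame is an atom-count line, one comment line, then that many
    atom lines. Blank lines between frames are skipped.
    """
    lines = text.splitlines()
    i = 0
    n_frames = 0
    while i < len(lines):
        if not lines[i].strip():
            i += 1
            continue
        n_atoms = int(lines[i].split()[0])
        if i + 2 + n_atoms > len(lines):
            raise ValueError(f"frame {n_frames} starting at line {i + 1} is truncated")
        i += 2 + n_atoms  # skip count line, comment line and atom lines
        n_frames += 1
    return n_frames


def check_split(n_frames: int, n_train: int, n_val: int) -> None:
    """Raise if the requested train/validation split needs more frames than exist."""
    if n_train + n_val > n_frames:
        raise ValueError(
            f"n_train + n_val = {n_train + n_val}, but only {n_frames} frames were found"
        )


# Two tiny made-up frames, just to exercise the functions.
sample = """2
frame one
C 0.0 0.0 0.0
C 1.4 0.0 0.0
2
frame two
C 0.0 0.0 0.0
C 1.5 0.0 0.0
"""
n = count_xyz_frames(sample)
print(n)  # 2
check_split(n, n_train=1, n_val=1)  # passes; n_train=200, n_val=49 would raise
```

If the count from the raw file disagrees with what the loader reports, the problem is in how the file is parsed rather than in the split settings.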

I modified nequip/train/trainer.py to print total_n, the number of datapoints it sees. It comes out as 2, but my dataset has 249 datapoints. Here are the first 2 as an example:

6
Properties=species:S:1:charge:R:1:pos:R:3:forces:R:3 energy=3.3568491152989464E+00 pbc="F F F"
C    2.8860151849165884E-02   -1.9388596926913013E-02   -2.8759623532924683E-03   -2.6833257705618396E+00   -5.0188550432362087E+00   -7.8061520280432894E+00 
C   -2.1500824022609995E-02    3.6524386361154572E-03   -1.9698357525108050E-03    2.3421032349874338E+00   -6.6902716186377642E+00   -1.7849302559666409E+00 
C    2.3387800610515291E-02    3.8019240081639564E-03   -3.6286631993187970E-03   -3.5688662248506908E+00   -8.1820582146951519E+00    2.9663094116194868E+00 
C    1.3905598298491445E-02    7.5721377048330769E-03    8.0048531530292778E-03    9.5122339087921475E+00   -9.7041169051925280E-01    1.3042310895041316E+00 
C   -4.5729982211133953E-02    3.5546925801004338E-03    2.9333833145097753E-03   -1.2541717017776133E-01   -8.2066513548131166E+00    8.1008209468107211E+00 
C    1.0772554755713285E-03    8.0740399770008677E-04   -2.4637751624169836E-03   -4.5859471646614072E+00    3.5845739562088865E+00    8.2385401713477435E+00 
6
Properties=species:S:1:charge:R:1:pos:R:3:forces:R:3 energy=3.4898710368025143E+00 pbc="F F F"
C    7.6319141601180404E-03    8.9401622864701772E-03   -1.7921386926762573E-03    5.1772637533841459E-01   -2.7859567151648097E+00   -2.8128901836587339E-01 
C   -7.1088268866891773E-03    2.4796940863442618E-03   -2.0363664484700790E-02    8.5094886685353437E+00    3.2758187190546590E+00   -4.6838215430390271E+00 
C    1.4622996296095836E-02   -2.3619554337368280E-02   -6.2698372277412611E-03    9.1124305059183399E+00   -4.1808210366970284E+00    4.2946053591808280E+00 
C   -1.3909049200163072E-02    3.1250438952696268E-04   -2.4999557169415663E-03   -3.9259066958050122E+00   -6.5920997070775220E+00    6.8063951037962402E+00 
C   -2.4331619371849824E-02    6.8812024942755349E-03    3.4277307003635726E-03    2.0102930012740594E+00    4.1305435735101490E+00    8.8513121399612977E+00 
C    2.3094585002488198E-02    5.0059910807513461E-03    2.7497865421696303E-02    8.3670043621588555E+00    2.4338406515288975E+00    8.2732601005813198E+00 

Here is the relevant part of my config YAML:

dataset: ase
dataset_file_name: training.extxyz
ase_args:
 format: extxyz

chemical_symbols:
  - C

# training
n_train: 200
n_val: 49
batch_size: 1
max_epochs: 1
learning_rate: 0.002

What's wrong with my input? I have 249 structures, but NequIP thinks there are 2. I've checked the ASE, NequIP and Allegro docs, and everything looks OK. This case has no PBC, but when I run it with (physically valid) lattice parameters and pbc="T T T", I still get the same error. I'm sure I've made a mistake somewhere, but I can't find it.

Thanks in advance.

DavidW99 commented 8 months ago

I would suggest iterating through the structures in your xyz file with ASE. That way you can check how many structures ASE actually finds and inspect each frame of your trajectory.

For instance,

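from ase.io import read

path = 'my_file.xyz'
# index=':' returns every frame as a list of Atoms objects, not just the last one
trj = read(path, format='extxyz', index=':')
print(len(trj))
for frame in trj:
    print(frame.info)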

Meanwhile, I notice your header declares charge:R:1 in Properties, but your atom lines have no charge column. You need to delete that part.
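
In extxyz, the Properties entry declares the per-atom columns as name:type:count triples. The header species:S:1:charge:R:1:pos:R:3:forces:R:3 therefore promises 1 + 1 + 3 + 3 = 8 columns, but each atom line in the post has only 7. The sketch below checks that consistency using only the standard library. The helper names are hypothetical, and ASE does its own parsing, so it may report this problem differently or not at all.

```python
from typing import List


def declared_columns(properties: str) -> int:
    """Total number of columns declared by an extxyz Properties string.

    The string is a sequence of name:type:count triples, e.g.
    'species:S:1:pos:R:3' declares 1 + 3 = 4 columns.
    """
    fields = properties.split(":")
    if len(fields) % 3 != 0:
        raise ValueError(f"malformed Properties string: {properties!r}")
    return sum(int(fields[k + 2]) for k in range(0, len(fields), 3))


def mismatched_atom_lines(comment: str, atom_lines: List[str]) -> List[int]:
    """Indices of atom lines whose column count differs from the header."""
    properties = None
    for token in comment.split():
        if token.startswith("Properties="):
            properties = token[len("Properties="):]
    if properties is None:
        raise ValueError("comment line has no Properties= entry")
    expected = declared_columns(properties)
    return [i for i, line in enumerate(atom_lines) if len(line.split()) != expected]


# Header and first atom line adapted from the original post, values shortened.
bad_header = 'Properties=species:S:1:charge:R:1:pos:R:3:forces:R:3 energy=3.3568 pbc="F F F"'
good_header = 'Properties=species:S:1:pos:R:3:forces:R:3 energy=3.3568 pbc="F F F"'
atom = "C 2.886E-02 -1.939E-02 -2.876E-03 -2.683E+00 -5.019E+00 -7.806E+00"

print(mismatched_atom_lines(bad_header, [atom]))   # [0]: 8 columns declared, 7 present
print(mismatched_atom_lines(good_header, [atom]))  # []
```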

jonathan-booth commented 8 months ago

I ran that test, and it reports four structures for both the .xyz that I read in to create the .extxyz and the .extxyz itself. Here is the .xyz I used:

6
test structure
C    2.8860151849165884E-02   -1.9388596926913013E-02   -2.8759623532924683E-03
C   -2.1500824022609995E-02    3.6524386361154572E-03   -1.9698357525108050E-03
C    2.3387800610515291E-02    3.8019240081639564E-03   -3.6286631993187970E-03
C    1.3905598298491445E-02    7.5721377048330769E-03    8.0048531530292778E-03
C   -4.5729982211133953E-02    3.5546925801004338E-03    2.9333833145097753E-03
C    1.0772554755713285E-03    8.0740399770008677E-04   -2.4637751624169836E-03
6
test structure
C    7.6319141601180404E-03    8.9401622864701772E-03   -1.7921386926762573E-03
C   -7.1088268866891773E-03    2.4796940863442618E-03   -2.0363664484700790E-02
C    1.4622996296095836E-02   -2.3619554337368280E-02   -6.2698372277412611E-03
C   -1.3909049200163072E-02    3.1250438952696268E-04   -2.4999557169415663E-03
C   -2.4331619371849824E-02    6.8812024942755349E-03    3.4277307003635726E-03
C    2.3094585002488198E-02    5.0059910807513461E-03    2.7497865421696303E-02
6
test structure
C    1.2387412163608045E-03   -8.5574951298672333E-03    7.7241088377932293E-03
C    5.9518605914250034E-03   -8.2517216861367447E-03   -2.1049058796149656E-02
C   -4.4181294087426055E-03    1.1257915490579735E-02    2.0296613820927602E-02
C    2.2750592516949414E-02    1.9876254153660830E-02   -2.2428474763927911E-02
C   -1.4535718330663154E-02   -9.2286264017176246E-03    2.4743701232480268E-03
C   -1.0987346585329461E-02   -5.0963264265189564E-03    1.2982440778108711E-02
6
test structure
C    1.5051670917489381E-02    4.6245535853670578E-03    7.9809236441371108E-03
C   -9.9796348997605117E-02   -1.3272694438123877E-02   -1.0976961125669295E-02
C    6.4515734398151017E-02   -5.1139630075321005E-02   -9.1306982503531042E-03
C   -6.3198467069126185E-02    3.5143227790004047E-02    1.1880960347834320E-02
C    9.1446738601495950E-02   -2.8256736429113386E-03    6.2237372171349309E-02
C   -8.0193278504050412E-03    2.7470216780985116E-02   -6.1991596787298341E-02

And here is the output of the code you gave me above to count and check frames:

4
{'test': True, 'structure': True}
{'test': True, 'structure': True}
{'test': True, 'structure': True}
{'test': True, 'structure': True}
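
These info dictionaries show that the plain comment line "test structure" was read as extxyz key=value pairs. Each bare word became a key with the value True, and no energy was picked up. The sketch below is a simplified, stdlib-only imitation of that behaviour. parse_comment is a hypothetical helper; ASE's real parser also converts numbers, arrays and T/F booleans.

```python
import shlex
from typing import Dict, Union


def parse_comment(comment: str) -> Dict[str, Union[str, bool]]:
    """Split an extxyz-style comment line into key=value pairs.

    Quoted values such as pbc="F F F" stay together; a bare word with no
    '=' is stored as the flag True.
    """
    info: Dict[str, Union[str, bool]] = {}
    for token in shlex.split(comment):
        key, sep, value = token.partition("=")
        info[key] = value if sep else True
    return info


print(parse_comment("test structure"))
# {'test': True, 'structure': True}
print(parse_comment('energy=3.36 pbc="F F F"'))
# {'energy': '3.36', 'pbc': 'F F F'}
```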

Here is my code for reading the .xyz file and making the .extxyz:

from ase.io import read, write
import numpy as np

in_filename = 'test.xyz'
out_filename = 'training.extxyz'

test_forces = np.array(ReadForces())  # calls my own function that reads forces from the .txt files generated by my simulation

# read all frames into a list of ase.Atoms objects
frames = read(in_filename, format='xyz', index=':')

for i, atoms in enumerate(frames):
    atoms.set_pbc(False)
    atoms.new_array('forces', test_forces[i])  # per-atom forces for frame i

write(out_filename, frames, append=False, format='extxyz')

And here is the .extxyz file made by the above code. I still get the error where NequIP thinks there are only two datapoints.

6
Properties=species:S:1:pos:R:3:forces:R:3 pbc="F F F"
C        0.02886015      -0.01938860      -0.00287596      -2.68332577      -5.01885504      -7.80615203
C       -0.02150082       0.00365244      -0.00196984       2.34210323      -6.69027162      -1.78493026
C        0.02338780       0.00380192      -0.00362866      -3.56886622      -8.18205821       2.96630941
C        0.01390560       0.00757214       0.00800485       9.51223391      -0.97041169       1.30423109
C       -0.04572998       0.00355469       0.00293338      -0.12541717      -8.20665135       8.10082095
C        0.00107726       0.00080740      -0.00246378      -4.58594716       3.58457396       8.23854017
6
Properties=species:S:1:pos:R:3:forces:R:3 pbc="F F F"
C        0.00763191       0.00894016      -0.00179214       0.51772638      -2.78595672      -0.28128902
C       -0.00710883       0.00247969      -0.02036366       8.50948867       3.27581872      -4.68382154
C        0.01462300      -0.02361955      -0.00626984       9.11243051      -4.18082104       4.29460536
C       -0.01390905       0.00031250      -0.00249996      -3.92590670      -6.59209971       6.80639510
C       -0.02433162       0.00688120       0.00342773       2.01029300       4.13054357       8.85131214
C        0.02309459       0.00500599       0.02749787       8.36700436       2.43384065       8.27326010
6
Properties=species:S:1:pos:R:3:forces:R:3 pbc="F F F"
C        0.00123874      -0.00855750       0.00772411      -4.46553474      -2.48566747      -5.61580909
C        0.00595186      -0.00825172      -0.02104906      -6.64210137       8.01154524      -3.33580852
C       -0.00441813       0.01125792       0.02029661      -6.43155815      -9.67839883       3.46007324
C        0.02275059       0.01987625      -0.02242847       4.22526203      -0.91577443       8.91407516
C       -0.01453572      -0.00922863       0.00247437       9.85511079      -0.54486112       5.88790805
C       -0.01098735      -0.00509633       0.01298244       3.74308385       4.91803739       5.23557311
6
Properties=species:S:1:pos:R:3:forces:R:3 pbc="F F F"
C        0.01505167       0.00462455       0.00798092       9.05862076      -3.86747089      -6.10817828
C       -0.09979635      -0.01327269      -0.01097696      -4.63311512      -4.34674459      -0.75442826
C        0.06451573      -0.05113963      -0.00913070      -0.67961307       8.80642408      -3.61365465
C       -0.06319847       0.03514323       0.01188096       2.91468987       6.48905660      -3.73199774
C        0.09144674      -0.00282567       0.06223737      -1.33422343      -4.06850953       0.16283979
C       -0.00801933       0.02747022      -0.06199160      -2.23673595      -6.95618745       4.84691031

And here is the output of the frame counting code you gave me above:

4
{}
{}
{}
{}
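
None of the four headers written by the conversion script contains an energy= entry. According to the replies that follow, adding an energy to every structure is what allowed training to run. The sketch below is a stdlib-only frame formatter that always writes the energy into the header. format_extxyz_frame and the numbers in the example are hypothetical. If you stay within ASE, one common route is to attach each frame's energy and forces with ase.calculators.singlepoint.SinglePointCalculator before calling write.

```python
from typing import Sequence, Tuple

Vec3 = Tuple[float, float, float]


def format_extxyz_frame(
    symbols: Sequence[str],
    positions: Sequence[Vec3],
    forces: Sequence[Vec3],
    energy: float,
) -> str:
    """Return one non-periodic extxyz frame whose header carries the energy."""
    if not len(symbols) == len(positions) == len(forces):
        raise ValueError("symbols, positions and forces must have the same length")
    header = (
        "Properties=species:S:1:pos:R:3:forces:R:3 "
        f'energy={float(energy)!r} pbc="F F F"'
    )
    lines = [str(len(symbols)), header]
    for symbol, pos, frc in zip(symbols, positions, forces):
        # 3 position columns followed by 3 force columns, as declared in Properties
        cols = " ".join(f"{x:16.8f}" for x in (*pos, *frc))
        lines.append(f"{symbol:<2} {cols}")
    return "\n".join(lines) + "\n"


frame = format_extxyz_frame(
    ["C", "C"],
    [(0.0, 0.0, 0.0), (1.4, 0.0, 0.0)],
    [(0.1, 0.0, 0.0), (-0.1, 0.0, 0.0)],
    energy=3.35,
)
print(frame)
```

Concatenating the returned strings for all frames gives a complete multi-frame file.
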

DavidW99 commented 8 months ago

Thanks for checking this! I just ran a training with your 4-structure xyz file. After adding an energy entry to each structure, Allegro trains without the above error. Could you try creating a separate environment, installing NequIP from the latest develop branch, and then installing Allegro?
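
Adding an energy to each structure is what fixed the error here, so it can be worth checking a file before launching nequip-train. The sketch below uses only the standard library and lists the frames whose header lacks an energy= entry. frames_missing_energy is a hypothetical helper. It assumes the key is literally named energy, as in the headers posted in this thread.

```python
import re
from typing import List

# Matches an energy=... key at the start of the comment line or after whitespace.
ENERGY_KEY = re.compile(r"(?:^|\s)energy=")


def frames_missing_energy(text: str) -> List[int]:
    """Return the indices of frames whose comment line has no energy= key."""
    lines = text.splitlines()
    missing: List[int] = []
    i = frame = 0
    while i < len(lines):
        if not lines[i].strip():
            i += 1  # skip blank lines between frames
            continue
        n_atoms = int(lines[i].split()[0])
        comment = lines[i + 1] if i + 1 < len(lines) else ""
        if not ENERGY_KEY.search(comment):
            missing.append(frame)
        i += 2 + n_atoms  # jump to the next frame's atom-count line
        frame += 1
    return missing


# Typical use: print(frames_missing_energy(open("training.extxyz").read()))
```

For the file produced by the conversion script earlier in the thread, this would flag all four frames.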

jonathan-booth commented 8 months ago

No problem. It works when I do that, thanks for all your help! In case anyone else runs into this problem in the future, here is what I did:

Fix the .extxyz so that every structure has an energy. (Originally, I still had this problem even when the energy tag was there.)

Make a new environment.

conda install pytorch==1.10
pip install wandb

git clone https://github.com/mir-group/nequip.git
cd nequip
pip install .

git clone https://github.com/mir-group/allegro.git
cd allegro
pip install .

My environment looks like this:

appdirs                   1.4.4                    pypi_0    pypi
ase                       3.22.1                   pypi_0    pypi
ca-certificates           2023.12.12           hca03da5_0  
certifi                   2024.2.2                 pypi_0    pypi
cffi                      1.16.0           py39h80987f9_0  
charset-normalizer        3.3.2                    pypi_0    pypi
click                     8.1.7                    pypi_0    pypi
contourpy                 1.2.0                    pypi_0    pypi
cycler                    0.12.1                   pypi_0    pypi
docker-pycreds            0.4.0                    pypi_0    pypi
e3nn                      0.5.1                    pypi_0    pypi
exceptiongroup            1.2.0                    pypi_0    pypi
fonttools                 4.49.0                   pypi_0    pypi
future                    0.18.3           py39hca03da5_0  
gitdb                     4.0.11                   pypi_0    pypi
gitpython                 3.1.42                   pypi_0    pypi
idna                      3.6                      pypi_0    pypi
importlib-resources       6.1.3                    pypi_0    pypi
iniconfig                 2.0.0                    pypi_0    pypi
kiwisolver                1.4.5                    pypi_0    pypi
libblas                   3.9.0           16_osxarm64_openblas    conda-forge
libcblas                  3.9.0           16_osxarm64_openblas    conda-forge
libcxx                    14.0.6               h848a8c0_0  
libffi                    3.4.4                hca03da5_0  
libgfortran               5.0.0           11_3_0_hca03da5_28  
libgfortran5              11.3.0              h009349e_28  
liblapack                 3.9.0           16_osxarm64_openblas    conda-forge
libopenblas               0.3.21               h269037a_0  
libprotobuf               3.19.6               h514c7bf_0  
llvm-openmp               14.0.6               hc6e5704_0  
matplotlib                3.8.3                    pypi_0    pypi
mir-allegro               0.2.0                    pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
ncurses                   6.4                  h313beb8_0  
nequip                    0.5.6                    pypi_0    pypi
ninja                     1.10.2               hca03da5_5  
ninja-base                1.10.2               h525c30c_5  
numpy                     1.26.4                   pypi_0    pypi
openssl                   3.0.13               h1a28f6b_0  
opt-einsum                3.3.0                    pypi_0    pypi
opt-einsum-fx             0.1.4                    pypi_0    pypi
packaging                 23.2                     pypi_0    pypi
pillow                    10.2.0                   pypi_0    pypi
pip                       23.3.1           py39hca03da5_0  
pluggy                    1.4.0                    pypi_0    pypi
protobuf                  4.25.3                   pypi_0    pypi
psutil                    5.9.8                    pypi_0    pypi
pycparser                 2.21               pyhd3eb1b0_0  
pyparsing                 3.1.2                    pypi_0    pypi
pytest                    8.0.2                    pypi_0    pypi
python                    3.9.7           hc0da0df_3_cpython    conda-forge
python-dateutil           2.9.0.post0              pypi_0    pypi
python_abi                3.9                      4_cp39    conda-forge
pytorch                   1.10.0          cpu_py39hbfdb42d_1    conda-forge
pyyaml                    6.0.1                    pypi_0    pypi
readline                  8.2                  h1a28f6b_0  
requests                  2.31.0                   pypi_0    pypi
scipy                     1.12.0                   pypi_0    pypi
sentry-sdk                1.41.0                   pypi_0    pypi
setproctitle              1.3.3                    pypi_0    pypi
setuptools                68.2.2           py39hca03da5_0  
six                       1.16.0                   pypi_0    pypi
sleef                     3.5.1                h80987f9_2  
smmap                     5.0.1                    pypi_0    pypi
sqlite                    3.41.2               h80987f9_0  
sympy                     1.12                     pypi_0    pypi
tk                        8.6.12               hb8d0fd4_0  
tomli                     2.0.1                    pypi_0    pypi
torch-ema                 0.3                      pypi_0    pypi
torch-runstats            0.2.0                    pypi_0    pypi
tqdm                      4.66.2                   pypi_0    pypi
typing_extensions         4.9.0            py39hca03da5_1  
tzdata                    2024a                h04d1e81_0  
urllib3                   2.2.1                    pypi_0    pypi
wandb                     0.16.4                   pypi_0    pypi
wheel                     0.41.2           py39hca03da5_0  
xz                        5.4.6                h80987f9_0  
zipp                      3.17.0                   pypi_0    pypi
zlib                      1.2.13               h5a0b063_0
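
To compare another environment against the listing above, you can print the installed versions of the key packages with the standard library's importlib.metadata. This is a minimal sketch. The distribution names are assumptions: torch for PyTorch, and mir-allegro as shown in the listing.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Assumed distribution names; "mir-allegro" matches the environment listing above.
PACKAGES = ("torch", "e3nn", "ase", "nequip", "mir-allegro", "wandb")


def installed_versions(names: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:12} {version or 'not installed'}")
```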