txie-93 / cgcnn

Crystal graph convolutional neural networks for predicting material properties.
MIT License
651 stars 309 forks

Size mismatch when loading pre-trained models #37

Open Rich2333 opened 2 years ago

Rich2333 commented 2 years ago


When I try to load the pre-trained models to test predict.py, I get the following error:

python predict.py pre-trained/final-energy-per-atom.pth.tar mp/
=> loading model params 'pre-trained/final-energy-per-atom.pth.tar'
=> loaded model params 'pre-trained/final-energy-per-atom.pth.tar'
=> loading model 'pre-trained/final-energy-per-atom.pth.tar'
Traceback (most recent call last):
  File "E:\cgcnn-master\predict.py", line 298, in <module>
    main()
  File "E:\cgcnn-master\predict.py", line 94, in main
    model.load_state_dict(checkpoint['state_dict'])
  File "C:\ProgramData\Anaconda3\envs\cgcnn1\lib\site-packages\torch\nn\modules\module.py", line 1497, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for CrystalGraphConvNet:
    size mismatch for convs.0.fc_full.weight: copying a param with shape torch.Size([128, 169]) from checkpoint, the shape in current model is torch.Size([128, 179]).
    size mismatch for convs.1.fc_full.weight: copying a param with shape torch.Size([128, 169]) from checkpoint, the shape in current model is torch.Size([128, 179]).
    size mismatch for convs.2.fc_full.weight: copying a param with shape torch.Size([128, 169]) from checkpoint, the shape in current model is torch.Size([128, 179]).
    size mismatch for convs.3.fc_full.weight: copying a param with shape torch.Size([128, 169]) from checkpoint, the shape in current model is torch.Size([128, 179]).
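For anyone debugging the same traceback, here is a minimal sketch of how the mismatched shapes can be compared and interpreted. It assumes that `fc_full` is a linear layer mapping `2 * atom_fea_len + nbr_fea_len` inputs to `2 * atom_fea_len` outputs, which is my reading of the cgcnn model code and is not confirmed in this thread. The helper names `find_shape_mismatches` and `implied_nbr_fea_len` are hypothetical and not part of cgcnn.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def find_shape_mismatches(
    checkpoint_shapes: Dict[str, Shape],
    model_shapes: Dict[str, Shape],
) -> List[Tuple[str, Shape, Shape]]:
    """Return (name, checkpoint_shape, model_shape) for parameters whose shapes differ."""
    mismatches = []
    for name, ckpt_shape in checkpoint_shapes.items():
        model_shape = model_shapes.get(name)
        if model_shape is not None and tuple(model_shape) != tuple(ckpt_shape):
            mismatches.append((name, tuple(ckpt_shape), tuple(model_shape)))
    return mismatches


def implied_nbr_fea_len(fc_full_shape: Shape) -> int:
    """ASSUMPTION: weight shape is (2 * atom_fea_len, 2 * atom_fea_len + nbr_fea_len)."""
    out_features, in_features = fc_full_shape
    return in_features - out_features


# Shapes copied from the traceback above. With real objects they could be built as
#   {k: tuple(v.shape) for k, v in checkpoint['state_dict'].items()}
#   {k: tuple(v.shape) for k, v in model.state_dict().items()}
checkpoint_shapes = {"convs.0.fc_full.weight": (128, 169)}
model_shapes = {"convs.0.fc_full.weight": (128, 179)}

mismatches = find_shape_mismatches(checkpoint_shapes, model_shapes)
for name, ckpt_shape, model_shape in mismatches:
    print(
        name,
        "checkpoint nbr_fea_len:", implied_nbr_fea_len(ckpt_shape),
        "model nbr_fea_len:", implied_nbr_fea_len(model_shape),
    )
```

Under that assumption the checkpoint was built with a neighbor (bond) feature vector of length 41 and the current model with length 51. That would point to a difference in how the bond features are generated for the dataset rather than to a package version, though nothing in this thread confirms the cause.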

By the way, I then tried training my own model and using it for prediction. The errors above did not show up, but the MAE was far too large.

(cgcnn) E:\cgcnn-master>python predict.py E:\cgcnn-master\trained_files\from_cmd\mp-2\mp_model_best.pth.tar mp/
=> loading model params 'E:\cgcnn-master\trained_files\from_cmd\mp-2\mp_model_best.pth.tar'
=> loaded model params 'E:\cgcnn-master\trained_files\from_cmd\mp-2\mp_model_best.pth.tar'
=> loading model 'E:\cgcnn-master\trained_files\from_cmd\mp-2\mp_model_best.pth.tar'
=> loaded model 'E:\cgcnn-master\trained_files\from_cmd\mp-2\mp_model_best.pth.tar' (epoch 484, validation 0.05862389877438545)
C:\ProgramData\Anaconda3\envs\cgcnn\lib\site-packages\pymatgen\io\cif.py:1155: UserWarning: Issues encountered while parsing CIF: Some fractional coordinates rounded to ideal values to avoid issues with finite precision.
  warnings.warn("Issues encountered while parsing CIF: " + "\n".join(self.warnings))
Test: [0/74] Time 26.633 (26.633) Loss inf (inf) MAE 5.977 (5.977)
Test: [10/74] Time 24.787 (27.052) Loss inf (inf) MAE 6.005 (6.013)
Test: [20/74] Time 28.383 (28.096) Loss inf (inf) MAE 5.941 (6.010)
Test: [30/74] Time 31.305 (28.518) Loss inf (inf) MAE 6.081 (6.008)
Test: [40/74] Time 30.491 (29.037) Loss inf (inf) MAE 5.860 (6.010)
Test: [50/74] Time 35.822 (29.651) Loss inf (inf) MAE 6.035 (6.008)
Test: [60/74] Time 33.488 (30.191) Loss inf (inf) MAE 6.033 (6.012)
Test: [70/74] Time 34.823 (30.565) Loss inf (inf) MAE 5.955 (6.008)
 ** MAE 6.009

Thanks for your attention!

txie-93 commented 2 years ago

Thanks for reaching out. Can you provide more details? What is your train/test data? What are other hyperparameters?

Rich2333 commented 2 years ago

My train/test data are crystal structures downloaded from the Materials Project database (using the MPRester API). When using my own trained model, I use the same data for training and prediction. I'm quite confused that the prediction MAE (e.g. 5.977) is about 100 times larger than the training MAE (e.g. 0.0586).

As for hyperparameters, I don't think I changed any defaults other than the number of epochs (and the data split ratios shown below):

python main.py --epochs 500 --train-ratio 0.6 --val-ratio 0.2 --test-ratio 0.2 mp/
python predict.py E:/cgcnn-master/trained_files/from_cmd/mp-2/mp_model_best.pth.tar mp/
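The `Loss inf` entries in the prediction log may mean that a non-finite number is entering the loss, so one sanity check is to scan the targets and predictions and recompute the MAE by hand. The sketch below assumes a three-column CSV of ID, target, and prediction. That layout, the sample rows, and the helper name `summarize_results` are assumptions for illustration, not taken from this thread.

```python
import csv
import io
import math
from typing import Tuple


def summarize_results(csv_text: str) -> Tuple[float, int]:
    """Return (MAE over finite rows, count of rows with a non-finite target or prediction)."""
    abs_errors = []
    n_bad = 0
    for row in csv.reader(io.StringIO(csv_text)):
        if len(row) < 3:
            continue  # skip blank or malformed lines
        target, pred = float(row[1]), float(row[2])
        if math.isfinite(target) and math.isfinite(pred):
            abs_errors.append(abs(target - pred))
        else:
            n_bad += 1
    mae = sum(abs_errors) / len(abs_errors) if abs_errors else float("nan")
    return mae, n_bad


# Hypothetical sample rows: (ID, target, prediction).
sample = "mp-1,0.50,0.45\nmp-2,-1.20,-1.10\nmp-3,inf,0.30\n"
mae, n_bad = summarize_results(sample)
print(f"MAE over finite rows: {mae:.3f}, non-finite rows: {n_bad}")
```

If the recomputed MAE over finite rows is close to the validation MAE, the large reported value would more likely come from a few bad rows than from the model itself. If it is still large, the targets used for prediction probably differ from those used in training.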

Rich2333 commented 2 years ago

And for the size mismatch problem, I'm wondering if my environment is different from the environment of pre-trained models.

The packages in my virtual environment are listed as follows:

#Name                    Version                   Build  Channel
ase                       3.22.1             pyhd8ed1ab_1    conda-forge
blas                      2.114                       mkl    conda-forge
blas-devel                3.9.0              14_win64_mkl    conda-forge
brotli                    1.0.9                h8ffe710_7    conda-forge
brotli-bin                1.0.9                h8ffe710_7    conda-forge
brotlipy                  0.7.0           py310he2412df_1004    conda-forge
bzip2                     1.0.8                h8ffe710_4    conda-forge
ca-certificates           2021.10.8            h5b45459_0    conda-forge
certifi                   2021.10.8       py310h5588dad_2    conda-forge
cffi                      1.15.0          py310hcbf9ad4_0    conda-forge
cftime                    1.6.0           py310h2873277_1    conda-forge
charset-normalizer        2.0.12             pyhd8ed1ab_0    conda-forge
click                     8.1.3           py310h5588dad_0    conda-forge
colorama                  0.4.4              pyh9f0ad1d_0    conda-forge
cryptography              36.0.2          py310ha857299_1    conda-forge
cudatoolkit               11.3.1               h59b6b97_2
curl                      7.83.0               h789b8ee_0    conda-forge
cycler                    0.11.0             pyhd8ed1ab_0    conda-forge
cython                    0.29.28         py310h8a704f9_2    conda-forge
double-conversion         3.2.0                h0e60522_0    conda-forge
eigen                     3.4.0                h2d74725_0    conda-forge
expat                     2.4.8                h39d44d4_0    conda-forge
ffmpeg                    4.3.1                ha925a31_0    conda-forge
flask                     2.1.2              pyhd8ed1ab_1    conda-forge
fonttools                 4.33.3          py310he2412df_0    conda-forge
freetype                  2.10.4               h546665d_1    conda-forge
future                    0.18.2          py310h5588dad_5    conda-forge
gl2ps                     1.4.2                h0597ee9_0    conda-forge
glew                      2.1.0                h39d44d4_2    conda-forge
hdf4                      4.2.15               h0e5069d_3    conda-forge
hdf5                      1.12.1          nompi_h2a0e4a3_104    conda-forge
icu                       69.1                 h0e60522_0    conda-forge
idna                      3.3                pyhd8ed1ab_0    conda-forge
importlib-metadata        4.11.3          py310h5588dad_1    conda-forge
intel-openmp              2022.0.0          h57928b3_3663    conda-forge
itsdangerous              2.1.2              pyhd8ed1ab_0    conda-forge
jbig                      2.1               h8d14728_2003    conda-forge
jinja2                    3.1.2              pyhd8ed1ab_0    conda-forge
joblib                    1.1.0              pyhd8ed1ab_0    conda-forge
jpeg                      9e                   h8ffe710_1    conda-forge
jsoncpp                   1.9.5                h2d74725_1    conda-forge
kiwisolver                1.4.2           py310h476a331_1    conda-forge
krb5                      1.19.3               h1176d77_0    conda-forge
latexcodec                2.0.1              pyh9f0ad1d_0    conda-forge
lcms2                     2.12                 h2a16943_0    conda-forge
lerc                      3.0                  h0e60522_0    conda-forge
libblas                   3.9.0              14_win64_mkl    conda-forge
libbrotlicommon           1.0.9                h8ffe710_7    conda-forge
libbrotlidec              1.0.9                h8ffe710_7    conda-forge
libbrotlienc              1.0.9                h8ffe710_7    conda-forge
libcblas                  3.9.0              14_win64_mkl    conda-forge
libclang                  13.0.1          default_h81446c8_0    conda-forge
libcurl                   7.83.0               h789b8ee_0    conda-forge
libdeflate                1.10                 h8ffe710_0    conda-forge
libffi                    3.4.2                h8ffe710_5    conda-forge
libiconv                  1.16                 he774522_0    conda-forge
liblapack                 3.9.0              14_win64_mkl    conda-forge
liblapacke                3.9.0              14_win64_mkl    conda-forge
libnetcdf                 4.8.1           nompi_h1cc8e9d_102    conda-forge
libogg                    1.3.4                h8ffe710_1    conda-forge
libpng                    1.6.37               h1d00b33_2    conda-forge
libssh2                   1.10.0               h680486a_2    conda-forge
libtheora                 1.1.1             h8d14728_1005    conda-forge
libtiff                   4.3.0                hc4061b1_3    conda-forge
libuv                     1.43.0               h8ffe710_0    conda-forge
libwebp                   1.2.2                h57928b3_0    conda-forge
libwebp-base              1.2.2                h8ffe710_1    conda-forge
libxcb                    1.13              hcd874cb_1004    conda-forge
libxml2                   2.9.14               hf5bbc77_0    conda-forge
libzip                    1.8.0                hfed4ece_1    conda-forge
libzlib                   1.2.11            h8ffe710_1014    conda-forge
loguru                    0.6.0           py310h5588dad_1    conda-forge
lz4-c                     1.9.3                h8ffe710_1    conda-forge
m2w64-gcc-libgfortran     5.3.0                         6    conda-forge
m2w64-gcc-libs            5.3.0                         7    conda-forge
m2w64-gcc-libs-core       5.3.0                         7    conda-forge
m2w64-gmp                 6.1.0                         2    conda-forge
m2w64-libwinpthread-git               2    conda-forge
markupsafe                2.1.1           py310he2412df_1    conda-forge
matplotlib-base           3.5.2           py310h79a7439_0    conda-forge
mkl                       2022.0.0           h0e2418a_796    conda-forge
mkl-devel                 2022.0.0           h57928b3_797    conda-forge
mkl-include               2022.0.0           h0e2418a_796    conda-forge
monty                     2022.4.26          pyhd8ed1ab_0    conda-forge
mpmath                    1.2.1              pyhd8ed1ab_0    conda-forge
msys2-conda-epoch         20160418                      1    conda-forge
munkres                   1.1.4              pyh9f0ad1d_0    conda-forge
netcdf4                   1.5.8           nompi_py310h5489b47_101    conda-forge
networkx                  2.8                pyhd8ed1ab_0    conda-forge
numpy                     1.22.3          py310hed7ac4c_2    conda-forge
openjpeg                  2.4.0                hb211442_1    conda-forge
openssl                   1.1.1o               h8ffe710_0    conda-forge
packaging                 21.3               pyhd8ed1ab_0    conda-forge
palettable                3.3.0                      py_0    conda-forge
pandas                    1.4.2           py310hf5e1058_1    conda-forge
pillow                    9.1.0           py310h767b3fd_2    conda-forge
pip                       22.0.4             pyhd8ed1ab_0    conda-forge
plotly                    5.7.0              pyhd8ed1ab_0    conda-forge
proj                      9.0.0                h1cfcee9_1    conda-forge
pthread-stubs             0.4               hcd874cb_1001    conda-forge
pugixml                   1.11.4               h0e60522_0    conda-forge
pybtex                    0.24.0             pyhd8ed1ab_2    conda-forge
pycparser                 2.21               pyhd8ed1ab_0    conda-forge
pymatgen                  2022.4.26       py310h476a331_0    conda-forge
pyopenssl                 22.0.0             pyhd8ed1ab_0    conda-forge
pyparsing                 3.0.8              pyhd8ed1ab_0    conda-forge
pysocks                   1.7.1           py310h5588dad_5    conda-forge
python                    3.10.4          h9a09f29_0_cpython    conda-forge
python-dateutil           2.8.2              pyhd8ed1ab_0    conda-forge
python_abi                3.10                    2_cp310    conda-forge
pytorch                   1.11.0          py3.10_cuda11.3_cudnn8_0    pytorch
pytorch-mutex             1.0                        cuda    pytorch
pytz                      2022.1             pyhd8ed1ab_0    conda-forge
pyyaml                    6.0             py310he2412df_4    conda-forge
qt                        5.12.9               h556501e_6    conda-forge
requests                  2.27.1             pyhd8ed1ab_0    conda-forge
ruamel.yaml               0.17.21         py310he2412df_1    conda-forge
ruamel.yaml.clib          0.2.6           py310he2412df_1    conda-forge
scikit-learn              1.0.2           py310h4dafddf_0    conda-forge
scipy                     1.8.0           py310h33db832_1    conda-forge
setuptools                62.1.0          py310h5588dad_0    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
spglib                    1.16.4          py310h2873277_0    conda-forge
sqlite                    3.38.4               h8ffe710_0    conda-forge
sympy                     1.10.1          py310h5588dad_0    conda-forge
tabulate                  0.8.9              pyhd8ed1ab_0    conda-forge
tbb                       2021.5.0             h2d74725_1    conda-forge
tbb-devel                 2021.5.0             h2d74725_1    conda-forge
tenacity                  8.0.1              pyhd8ed1ab_0    conda-forge
threadpoolctl             3.1.0              pyh8a188c0_0    conda-forge
tk                        8.6.12               h8ffe710_0    conda-forge
torchvision               0.12.0              py310_cu113    pytorch
tqdm                      4.64.0             pyhd8ed1ab_0    conda-forge
typing_extensions         4.2.0              pyha770c72_1    conda-forge
tzdata                    2022a                h191b570_0    conda-forge
ucrt                      10.0.20348.0         h57928b3_0    conda-forge
uncertainties             3.1.6              pyhd8ed1ab_0    conda-forge
unicodedata2              14.0.0          py310he2412df_1    conda-forge
urllib3                   1.26.9             pyhd8ed1ab_0    conda-forge
utfcpp                    3.2.1                h57928b3_0    conda-forge
vc                        14.2                 hb210afc_6    conda-forge
vs2015_runtime            14.29.30037          h902a5da_6    conda-forge
vtk                       9.1.0           qt_py310h99a8838_207    conda-forge
werkzeug                  2.1.2              pyhd8ed1ab_1    conda-forge
wheel                     0.37.1             pyhd8ed1ab_0    conda-forge
win32_setctime            1.1.0              pyhd8ed1ab_0    conda-forge
win_inet_pton             1.1.0           py310h5588dad_4    conda-forge
xorg-libxau               1.0.9                hcd874cb_0    conda-forge
xorg-libxdmcp             1.1.3                hcd874cb_0    conda-forge
xz                        5.2.5                h62dcd97_1    conda-forge
yaml                      0.2.5                h8ffe710_2    conda-forge
zipp                      3.8.0              pyhd8ed1ab_0    conda-forge
zlib                      1.2.11            h8ffe710_1014    conda-forge
zstd                      1.5.2                h6255e5f_0    conda-forge

SANTKJD commented 2 years ago

Dear author, I have trained on 40,000+ CIFs with “train size 0.8” and “epoch 1000”. I expected to get the same result as yours, but I only get an MAE of 0.049. Only 300~400 of all the CIFs differ from yours. Is my result correct within the margin of error?

liaokkkkk commented 1 year ago

Yeah, I also ran into this problem.

liaokkkkk commented 1 year ago

I found the cause: it's because the prediction code differs between models trained on a GPU and models trained on a CPU. There are tutorials online that make it fairly easy to solve.