openkim / kliff

KIM-based Learning-Integrated Fitting Framework for interatomic potentials.
https://kliff.readthedocs.io
GNU Lesser General Public License v2.1

Torch save and load functions #74

Closed by adityakavalur 1 year ago

adityakavalur commented 1 year ago

PyTorch allows saving the state of a model that can be re-read for additional training (or evaluation). This workflow has function calls in KLIFF under kliff/models/model_torch.py as well. However, the functionality does not seem to be working as expected.

To reproduce the issue, you can use any of the KLIFF examples; I am using the linear regression example. Add the following lines to the end of the script to re-read the pickle file:

model2 = LinearRegression(descriptor)
model_path = "./linear_model.pkl"
model2.load(model_path)

The error it returns is:

size mismatch for layer.weight: copying a param with shape torch.Size([30]) from checkpoint, the shape in current model is torch.Size([1, 30])

It looks like the bug is part of the Linear layer being read in here
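
The same message can be reproduced in plain PyTorch, independent of KLIFF. This is a minimal sketch, assuming a hypothetical `TinyLinear` module whose single `nn.Linear` is stored under the attribute name `layer` (chosen only so the key matches `layer.weight` in the error above). It flattens the saved weight to shape `[30]` and then tries to load it into a fresh model that expects `[1, 30]`.

```python
import torch
import torch.nn as nn


class TinyLinear(nn.Module):
    """Hypothetical stand-in for a linear model; not KLIFF's class."""

    def __init__(self, n_features=30):
        super().__init__()
        # nn.Linear(in, out) stores its weight with shape (out, in) = (1, 30)
        self.layer = nn.Linear(n_features, 1)

    def forward(self, x):
        return self.layer(x)


model = TinyLinear()
state = model.state_dict()

# Simulate a checkpoint whose weight was stored 1-D instead of 2-D.
state["layer.weight"] = state["layer.weight"].squeeze(0)  # shape (30,)

fresh = TinyLinear()
try:
    fresh.load_state_dict(state)
    error_message = None
except RuntimeError as err:
    error_message = str(err)

print(error_message)
```

If this sketch is representative, the mismatch comes from the shape of the tensor written to the checkpoint rather than from the PyTorch version used to read it.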

I tried torch versions 1.10.x and 1.12.x to no avail. I am pretty new to ML so please let me know if I am misunderstanding something.

Thanks for developing this tool! It is a useful utility!

mjwen commented 1 year ago

Hi @adityakavalur! This does not seem to be a PyTorch problem. Are you adding the lines to the bottom of the linear regression example or somewhere else?

adityakavalur commented 1 year ago

Thanks for looking at this @mjwen! Yes, I am adding the lines at the bottom of the example script. Here is the complete script:

"""
.. _tut_linear_regression:

Train a linear regression potential
===================================

In this tutorial, we train a linear regression model on the descriptors obtained using the
symmetry functions.
"""

from kliff.calculators import CalculatorTorch
from kliff.dataset import Dataset
from kliff.descriptors import SymmetryFunction
from kliff.models import LinearRegression
from kliff.utils import download_dataset
from kliff.utils import pickle_load

descriptor = SymmetryFunction(
    cut_name="cos", cut_dists={"Si-Si": 5.0}, hyperparams="set30", normalize=True
)

model = LinearRegression(descriptor)

# training set
dataset_path = download_dataset(dataset_name="Si_training_set")
dataset_path = dataset_path.joinpath("varying_alat")
tset = Dataset(dataset_path)
configs = tset.get_configs()

# calculator
calc = CalculatorTorch(model)
calc.create(configs, reuse=False)

##########################################################################################
# We can train a linear regression model by minimizing a loss function as discussed in
# :ref:`tut_nn`. But linear regression model has analytic solutions, and thus we can train
# the model directly by using this feature. This can be achieved by calling the ``fit()``
# function of its calculator.
#

# fit the model
calc.fit()

# save model
model.save("linear_model.pkl")

#AK: model load
model2 = LinearRegression(descriptor)
model_path = "./linear_model.pkl"
model2.load(model_path)

This issue likely stems from PyTorch or from how PyTorch is used in KLIFF. I have tried a few versions of PyTorch (1.10.x, 1.12.x), Python (3.8.x, 3.9.x), and OS (CentOS, Rocky), but always get the same error. Can you share which combination of these works for you with the example above?
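
Because the error is identical across versions, a version-independent check is to compare the shapes stored in a checkpoint against the shapes the model expects before calling `load_state_dict`. The sketch below is illustrative only: `find_shape_mismatches` is a hypothetical helper, not part of PyTorch or KLIFF, and the checkpoint is a hand-built dictionary rather than the real `linear_model.pkl`.

```python
import torch
import torch.nn as nn


def find_shape_mismatches(module, checkpoint_state):
    """Return {name: (checkpoint_shape, model_shape)} for entries whose shapes differ."""
    expected = module.state_dict()
    mismatches = {}
    for name, tensor in checkpoint_state.items():
        if name in expected and tuple(tensor.shape) != tuple(expected[name].shape):
            mismatches[name] = (tuple(tensor.shape), tuple(expected[name].shape))
    return mismatches


layer = nn.Linear(30, 1)  # expects weight (1, 30) and bias (1,)

# Hand-built checkpoint with a flattened weight, mimicking the reported error.
checkpoint = {"weight": torch.zeros(30), "bias": torch.zeros(1)}

mismatches = find_shape_mismatches(layer, checkpoint)
print(mismatches)
```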

I don't see this error when I use PyTorch directly, which is why I raised the issue in this repo. I created an example from the PyTorch documentation here that uses the same torch functionality, i.e. saving and loading a state_dict, and it works fine.

import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.relu in Net.forward
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

netA = Net()
netB = Net()

optimizerA = optim.SGD(netA.parameters(), lr=0.001, momentum=0.9)
optimizerB = optim.SGD(netB.parameters(), lr=0.001, momentum=0.9)
PATH = "model.pt"

torch.save({
            'modelA_state_dict': netA.state_dict(),
            'modelB_state_dict': netB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            }, PATH)

modelA = Net()
modelB = Net()
optimModelA = optim.SGD(modelA.parameters(), lr=0.001, momentum=0.9)
optimModelB = optim.SGD(modelB.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
# load into the optimizers that belong to the restored models
optimModelA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimModelB.load_state_dict(checkpoint['optimizerB_state_dict'])

print(modelA.state_dict())

modelA.eval()
modelB.eval()

mjwen commented 1 year ago

Thank you @adityakavalur for the detailed debug info!

Yes, this is a bug in the LinearRegression model in kliff, and it has been fixed in #75. Thank you very much for reporting this!
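
For readers who hit a similar mismatch in their own models, one plausible cause is an analytic fit that assigns the solved coefficients without restoring the parameter's 2-D shape. This is an assumption: the thread does not describe what #75 changed. The sketch below uses a hypothetical `TinyLinear` class with a hypothetical `fit_analytic` method. It solves the least-squares problem with `torch.linalg.lstsq` and copies the result into the existing parameters in place, so the shapes recorded by `state_dict()` stay `[1, n_features]` and a fresh model can load them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TinyLinear(nn.Module):
    """Hypothetical linear model; not KLIFF's LinearRegression."""

    def __init__(self, n_features):
        super().__init__()
        self.layer = nn.Linear(n_features, 1)

    def forward(self, x):
        return self.layer(x)

    def fit_analytic(self, X, y):
        # Prepend a column of ones so the first coefficient acts as the bias.
        ones = torch.ones(X.shape[0], 1, dtype=X.dtype)
        A = torch.cat([ones, X], dim=1)
        solution = torch.linalg.lstsq(A, y).solution  # shape (n_features + 1, 1)
        with torch.no_grad():
            # copy_ writes into the existing tensors, so their shapes are preserved
            self.layer.bias.copy_(solution[0])       # shape (1,)
            self.layer.weight.copy_(solution[1:].T)  # shape (1, n_features)


# Synthetic data, for illustration only.
X = torch.randn(100, 30)
y = torch.randn(100, 1)

model = TinyLinear(30)
model.fit_analytic(X, y)

# Round trip through a state_dict into a freshly constructed model.
restored = TinyLinear(30)
restored.load_state_dict(model.state_dict())
print(tuple(restored.layer.weight.shape))
```

Copying in place, rather than rebinding `self.layer.weight` to a new tensor, is the design choice that keeps saved and expected shapes aligned in this sketch.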

adityakavalur commented 1 year ago

Awesome! Thanks so much for the quick fix @mjwen! I can confirm that this works.