Motherboard opened this issue 8 years ago
You can find it here: https://github.com/hughperkins/pytorch/blob/master/test/test_save_load.py

```python
from __future__ import print_function, division
import PyTorch
import PyTorchAug
import numpy as np
# from test.test_helpers import myeval


def test_save_load():
    np.random.seed(123)
    a_np = np.random.randn(3, 2).astype(np.float32)
    a = PyTorch.asFloatTensor(a_np)
    print('a', a)
    filename = '/tmp/foo.t7'  # TODO: should use tempfile to get this
    PyTorchAug.save(filename, a)
    b = PyTorchAug.load(filename)
    print('type(b)', type(b))
    print('b', b)
    assert np.abs(a_np - b.asNumpyTensor()).max() < 1e-4
```
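The TODO in that test could be resolved with the standard-library `tempfile` module. The sketch below is an assumption, not the repository's code: it uses `np.save`/`np.load` as stand-ins for `PyTorchAug.save`/`PyTorchAug.load` so it runs without the Lua bindings, and `tempfile.mkstemp` (available on both Python 2 and 3) to get a unique path that is removed afterwards.

```python
import os
import tempfile

import numpy as np


def test_save_load_roundtrip():
    np.random.seed(123)
    a_np = np.random.randn(3, 2).astype(np.float32)
    # mkstemp creates a unique file and returns an open descriptor plus its path.
    fd, filename = tempfile.mkstemp(suffix='.npy')
    os.close(fd)  # only the path is needed; np.save reopens the file itself
    try:
        np.save(filename, a_np)   # stand-in for PyTorchAug.save(filename, a)
        b_np = np.load(filename)  # stand-in for PyTorchAug.load(filename)
    finally:
        os.remove(filename)  # clean up even if saving or loading fails
    assert np.abs(a_np - b_np).max() < 1e-4
    return b_np


b_np = test_save_load_roundtrip()
```

With the real bindings the suffix would presumably be `'.t7'` and the two stand-in calls would go back to `PyTorchAug.save`/`PyTorchAug.load`.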
I think this is the test I was referring to when I said I'd found one, but it doesn't load weights into a model: it just saves a NumPy array (converted to a FloatTensor) and loads it back.
I don't think there is any direct command in PyTorch/PyTorchAug to load weights from a .t7 file into a Torch model. You can, however, load a model together with its existing weights using `PyTorchAug.load()`.
For this particular task you need to copy the weights for each layer from the loaded file into the existing model's layers. PyTorch tensors have no `tensor.copy(newWeightTensor)` method, so as a workaround you can use `tensor.fill_(1).cmul(newWeightTensor)`: after filling with ones, the element-wise multiply leaves each element equal to the corresponding new weight. I have used this in the past and it works, but I couldn't find any method to change a tensor's type, say from `double` to `float`.
That's why I think there is no test case for it: there is no direct way to copy the weights over yet, though it should definitely be implemented.
Hi, I see `PyTorchAug.load()` was recently added, and there's a test that saves a NumPy array, loads it, and checks that the two match.
But there's no test checking that a network behaves as it should after inserting the weights from a .t7 file into it. I think that is a crucial test, and it could double as a usage example for anyone who wants to load a pre-trained model and predict with it.
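A test of that kind could compare a model's outputs after weight insertion against reference outputs on a fixed input. The sketch below is hypothetical throughout: a tiny NumPy MLP stands in for the Lua model, randomly generated float64 arrays stand in for the weights read from the `.t7` file, and the reference outputs are computed from those same weights rather than recorded from Torch.

```python
import numpy as np


def forward(layers, x):
    """Hypothetical MLP: y = x W^T + b per layer, ReLU between layers."""
    for i, layer in enumerate(layers):
        x = x.dot(layer['weight'].T) + layer['bias']
        if i < len(layers) - 1:
            x = np.maximum(x, 0)
    return x


def make_layers(sizes, dtype=np.float32):
    """Zero-initialized layers for layer widths such as [3, 5, 2]."""
    return [{'weight': np.zeros((n_out, n_in), dtype=dtype),
             'bias': np.zeros(n_out, dtype=dtype)}
            for n_in, n_out in zip(sizes[:-1], sizes[1:])]


def test_loaded_weights_reproduce_reference_outputs():
    rng = np.random.RandomState(123)
    sizes = [3, 5, 2]
    # Stand-in for weights read from a .t7 file (float64, like a DoubleTensor).
    loaded = [{'weight': rng.randn(n_out, n_in), 'bias': rng.randn(n_out)}
              for n_in, n_out in zip(sizes[:-1], sizes[1:])]
    x = rng.randn(4, sizes[0])
    # Stand-in for outputs recorded from the original model on the same input.
    expected = forward(loaded, x)

    model = make_layers(sizes)
    for dst, src in zip(model, loaded):
        np.copyto(dst['weight'], src['weight'])  # float64 -> float32 ('same_kind')
        np.copyto(dst['bias'], src['bias'])
    actual = forward(model, x.astype(np.float32))

    assert actual.shape == expected.shape
    max_err = np.abs(actual - expected).max()
    assert max_err < 1e-4
    # An untouched (all-zero) model should fail the same comparison.
    assert not np.allclose(forward(make_layers(sizes), x), expected)
    return max_err
```

The last assertion guards against a test that passes trivially because the weights were never actually copied.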
In the meantime, can someone enlighten me as to how to do it? (i.e., load the Lua model, then insert the .t7 weights into its layers, load an image, and call predict)
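The remaining steps (load an image, preprocess it, call predict) can be sketched in NumPy as well. Everything here is an assumption: the image is a random uint8 array standing in for a decoded file (for example `np.asarray(PIL.Image.open(path))`), the mean/std of 0.5 are placeholders for whatever normalization the pre-trained model actually expects, and the single linear layer is a hypothetical classifier rather than a real network.

```python
import numpy as np


def preprocess(image_hwc, mean, std):
    """Turn a uint8 HxWxC image into a float32 1xCxHxW batch."""
    x = image_hwc.astype(np.float32) / 255.0  # scale pixels to [0, 1]
    x = (x - mean) / std                      # per-channel normalization (broadcasts over last axis)
    x = x.transpose(2, 0, 1)                  # HWC -> CHW, the channel-first layout Torch uses
    return x[np.newaxis]                      # add a batch dimension of size 1


def predict(layers, batch):
    """Flatten each image, apply one linear layer, return the top class per image."""
    flat = batch.reshape(batch.shape[0], -1)
    scores = flat.dot(layers[0]['weight'].T) + layers[0]['bias']
    return scores.argmax(axis=1)


rng = np.random.RandomState(0)
image = rng.randint(0, 256, size=(8, 8, 3)).astype(np.uint8)  # placeholder image
mean = np.full(3, 0.5, dtype=np.float32)  # placeholder normalization constants
std = np.full(3, 0.5, dtype=np.float32)
num_classes = 10
# Hypothetical classifier weights; in practice these would be copied in from the .t7 file.
layers = [{'weight': rng.randn(num_classes, 3 * 8 * 8).astype(np.float32),
           'bias': np.zeros(num_classes, dtype=np.float32)}]
batch = preprocess(image, mean, std)
classes = predict(layers, batch)
```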