Closed — yinchen1233 closed this issue 5 months ago.
Since the embedding classes all inherit from nn.Module and register their parameters, you can use PyTorch's built-in methods for saving and loading model state (`state_dict()`, `torch.save`, `torch.load`, and `load_state_dict()`). See the example below:
import torchhd, torch
from torch import nn
from torchhd import embeddings
class Encoder(nn.Module):
    """Encode a batch of inputs into a single hypervector per sample.

    Each flattened input position is assigned a random hypervector
    (``self.position``), and each input value is mapped through a level
    embedding (``self.value``). The two are bound per position with
    ``torchhd.bind``, aggregated over all positions with
    ``torchhd.multiset``, and finally passed through
    ``torchhd.hard_quantize``.

    Both embeddings are registered submodules, so their weights appear in
    ``state_dict()`` under the keys ``position.weight`` and ``value.weight``.
    They can therefore be saved with ``torch.save`` and restored with
    ``load_state_dict``. The constructor arguments are not stored in the
    state_dict, so keep them to rebuild an identical encoder.

    Args:
        out_features: Dimensionality of every hypervector.
        size: The input is expected to flatten to ``size * size`` features
            (e.g. a ``size x size`` image). Presumably a square image —
            confirm against callers.
        levels: Number of level hypervectors used to embed input values.
    """

    def __init__(self, out_features: int, size: int, levels: int) -> None:
        super().__init__()
        self.flatten = nn.Flatten()
        # One random hypervector per flattened input position.
        self.position = embeddings.Random(size * size, out_features)
        # Level hypervectors for the input values.
        self.value = embeddings.Level(levels, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the hard-quantized sample hypervector(s) for ``x``.

        Args:
            x: Input batch. ``nn.Flatten`` keeps the first (batch) dimension
                and flattens the rest, which must total ``size * size``.

        Returns:
            The output of ``torchhd.hard_quantize`` applied to the bundled
            position/value bindings.
        """
        x = self.flatten(x)
        # Bind each position's hypervector with the hypervector of its value.
        sample_hv = torchhd.bind(self.position.weight, self.value(x))
        # Bundle the per-position bindings into one hypervector per sample.
        sample_hv = torchhd.multiset(sample_hv)
        return torchhd.hard_quantize(sample_hv)
# The constructor arguments are NOT part of the state_dict, but the saved
# tensor shapes depend on them. Keep them so an identical Encoder can be
# rebuilt before loading the saved weights.
OUT_FEATURES, SIZE, LEVELS = 1000, 28, 100
WEIGHTS_PATH = "model_weights.pt"

enc = Encoder(OUT_FEATURES, SIZE, LEVELS)
print(enc.state_dict())
# OrderedDict([('position.weight',
# MAPTensor([[ 1., 1., 1., ..., -1., 1., -1.],
# [ 1., 1., -1., ..., -1., -1., 1.],
# [-1., 1., 1., ..., 1., -1., 1.],
# ...,
# [ 1., 1., -1., ..., 1., -1., 1.],
# [ 1., 1., -1., ..., 1., 1., -1.],
# [ 1., 1., 1., ..., -1., -1., -1.]])),
# ('value.weight',
# MAPTensor([[-1., -1., 1., ..., -1., 1., 1.],
# [-1., -1., 1., ..., -1., 1., 1.],
# [-1., -1., 1., ..., -1., 1., 1.],
# ...,
# [-1., 1., 1., ..., -1., 1., 1.],
# [-1., 1., 1., ..., -1., 1., 1.],
# [-1., 1., 1., ..., -1., 1., 1.]]))])

# Save the model parameters (position.weight and value.weight).
torch.save(enc.state_dict(), WEIGHTS_PATH)

# Restore them into a freshly constructed encoder built with the same
# constructor arguments (e.g. in a later session or another script).
# map_location="cpu" lets the file load on machines without the device
# it was saved from.
restored = Encoder(OUT_FEATURES, SIZE, LEVELS)
# NOTE(review): recent PyTorch releases default torch.load to
# weights_only=True. If unpickling the MAPTensor subclass is rejected,
# allowlist it with torch.serialization.add_safe_globals. Only pass
# weights_only=False for files you trust.
restored.load_state_dict(torch.load(WEIGHTS_PATH, map_location="cpu"))
I hope this helps; let me know if you run into any problems.
If we want to reuse a saved model and need an encoder with the same parameters, how should we save the encoder?