hyperdimensional-computing / torchhd

Torchhd is a Python library for Hyperdimensional Computing and Vector Symbolic Architectures
https://torchhd.readthedocs.io
MIT License

How to save an encoder #169

Closed: yinchen1233 closed this issue 5 months ago

yinchen1233 commented 5 months ago
import torch
import torchhd
from torch import nn
from torchhd import embeddings

class Encoder(nn.Module):
    def __init__(self, out_features, size, levels):
        super(Encoder, self).__init__()
        self.flatten = torch.nn.Flatten()
        self.position = embeddings.Random(size * size, out_features)
        self.value = embeddings.Level(levels, out_features)

    def forward(self, x):
        x = self.flatten(x)
        sample_hv = torchhd.bind(self.position.weight, self.value(x))
        sample_hv = torchhd.multiset(sample_hv)
        return torchhd.hard_quantize(sample_hv)

If we want to use a saved model, we need an encoder with the same parameters. How do we save the encoder?

mikeheddes commented 5 months ago

Since the embedding classes all inherit from nn.Module and register their parameters, you can use PyTorch's built-in methods for saving and loading model state. See the example below:


import torch
import torchhd
from torch import nn
from torchhd import embeddings

class Encoder(nn.Module):
    def __init__(self, out_features, size, levels):
        super(Encoder, self).__init__()
        self.flatten = torch.nn.Flatten()
        self.position = embeddings.Random(size * size, out_features)
        self.value = embeddings.Level(levels, out_features)

    def forward(self, x):
        x = self.flatten(x)
        # bind each feature's value hypervector to its position hypervector
        sample_hv = torchhd.bind(self.position.weight, self.value(x))
        # bundle the bound hypervectors into one hypervector per sample
        sample_hv = torchhd.multiset(sample_hv)
        # binarize the result back to +1/-1 entries
        return torchhd.hard_quantize(sample_hv)

enc = Encoder(1000, 28, 100)

print(enc.state_dict())
# OrderedDict([('position.weight',
#              MAPTensor([[ 1.,  1.,  1.,  ..., -1.,  1., -1.],
#                         [ 1.,  1., -1.,  ..., -1., -1.,  1.],
#                         [-1.,  1.,  1.,  ...,  1., -1.,  1.],
#                         ...,
#                         [ 1.,  1., -1.,  ...,  1., -1.,  1.],
#                         [ 1.,  1., -1.,  ...,  1.,  1., -1.],
#                         [ 1.,  1.,  1.,  ..., -1., -1., -1.]])),
#              ('value.weight',
#               MAPTensor([[-1., -1.,  1.,  ..., -1.,  1.,  1.],
#                         [-1., -1.,  1.,  ..., -1.,  1.,  1.],
#                         [-1., -1.,  1.,  ..., -1.,  1.,  1.],
#                         ...,
#                         [-1.,  1.,  1.,  ..., -1.,  1.,  1.],
#                         [-1.,  1.,  1.,  ..., -1.,  1.,  1.],
#                         [-1.,  1.,  1.,  ..., -1.,  1.,  1.]]))])

# save the model parameters
torch.save(enc.state_dict(), "model_weights.pt")

# load the model parameters
enc.load_state_dict(torch.load("model_weights.pt"))
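
To restore the encoder in a new session, first construct it with the same hyperparameters and then load the saved tensors into it: load_state_dict overwrites the freshly drawn random hypervectors with the saved ones, so the loaded encoder produces the same encodings as the original. A minimal sketch, reusing the hyperparameter values from this example:

# in a fresh session: recreate the encoder with the same arguments
new_enc = Encoder(1000, 28, 100)
# overwrite its newly drawn random hypervectors with the saved ones
new_enc.load_state_dict(torch.load("model_weights.pt"))

Saving only the state dict, rather than pickling the whole module with torch.save(enc, ...), keeps the checkpoint independent of the class definition's import path.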

I hope this helps; let me know if you run into any problems.