masa-su / pixyz

A library for developing deep generative models in a more concise, intuitive and extendable way
https://pixyz.io
MIT License

Fix features expand: expand() -> expand().clone() #110

Closed: rnagumo closed this pull request 4 years ago

rnagumo commented 4 years ago

I found a bug in saving and loading Pixyz objects.

1. Problem

First, I create a Normal distribution instance and save its parameters with torch.save().

>>> import torch
>>> from pixyz.distributions import Normal
>>> z_dim = 2
>>> p = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), features_shape=[z_dim])
>>> torch.save(p.state_dict(), "./tmp.pt")

Next, when I load the saved file into a new object created in the same way, load_state_dict() raises a RuntimeError. The error message suggests a dimension mismatch between the model and the checkpoint, although both report the same size, torch.Size([1, 2]).

>>> q = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), features_shape=[z_dim])
>>> q.load_state_dict(torch.load("./tmp.pt"))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-8-0b2e959a1927> in <module>
----> 1 q.load_state_dict(torch.load("./tmp.pt"))

~/pixyz/.venv/lib/python3.7/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    828         if len(error_msgs) > 0:
    829             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 830                                self.__class__.__name__, "\n\t".join(error_msgs)))
    831         return _IncompatibleKeys(missing_keys, unexpected_keys)
    832 

RuntimeError: Error(s) in loading state_dict for Normal:
        While copying the parameter named "loc", whose dimensions in the model are torch.Size([1, 2]) and whose dimensions in the checkpoint are torch.Size([1, 2]).
        While copying the parameter named "scale", whose dimensions in the model are torch.Size([1, 2]) and whose dimensions in the checkpoint are torch.Size([1, 2]).
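The same error can likely be reproduced without Pixyz. The following sketch uses only plain PyTorch; ExpandedBuffer and round_trip are hypothetical names, and an in-memory io.BytesIO buffer stands in for ./tmp.pt. It registers an expanded scalar as a buffer, mimicking what Normal does with features_shape, and tries to load its own state_dict into a fresh instance.

```python
import io

import torch
from torch import nn


class ExpandedBuffer(nn.Module):
    """Hypothetical module: a scalar expanded to shape (1, 2), registered as a buffer."""

    def __init__(self):
        super().__init__()
        # expand() returns a view, so both elements share one memory location.
        self.register_buffer("loc", torch.tensor(0.).expand(1, 2))


def round_trip(src: nn.Module, dst: nn.Module) -> str:
    """Save src.state_dict() in memory, load it into dst, and return any error text."""
    checkpoint = io.BytesIO()  # stands in for ./tmp.pt
    torch.save(src.state_dict(), checkpoint)
    checkpoint.seek(0)
    try:
        dst.load_state_dict(torch.load(checkpoint))
    except RuntimeError as err:
        return str(err)
    return ""


message = round_trip(ExpandedBuffer(), ExpandedBuffer())
print(message)
```

Newer PyTorch versions may append the underlying exception text to the "While copying the parameter named ..." line, so the exact message can differ from the traceback above.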

I also tested another way of constructing the Normal distribution. The following is an equivalent Normal distribution with the same dimension, and it loads the saved parameters correctly.

>>> q = Normal(loc=torch.zeros(z_dim), scale=torch.ones(z_dim))
>>> q.load_state_dict(torch.load("./tmp.pt"))
<All keys matched successfully>

2. Change

The cause is the features.expand() call in the _check_features_shape() method, which is called when an object is created. When the given parameter is a scalar (its size is torch.Size([])), the DistributionBase class automatically expands it to features_shape with expand(), which returns a view without allocating new memory.
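For reference, a short plain-PyTorch sketch (not Pixyz code) of what expand() does to a scalar: the result is a view with stride 0 that points to the same memory as the original scalar, so it is not contiguous.

```python
import torch

scalar = torch.tensor(0.)      # size torch.Size([]), one element in memory
expanded = scalar.expand(2)    # view with size torch.Size([2])

print(expanded.shape)                            # torch.Size([2])
print(expanded.stride())                         # (0,): both elements share one location
print(expanded.data_ptr() == scalar.data_ptr())  # True: no new memory was allocated
print(expanded.is_contiguous())                  # False
```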

ref) https://pytorch.org/docs/stable/tensors.html#torch.Tensor.expand

However, loading saved tensors back into the model seems to require that each parameter have its full memory allocated (I could not find a reference for this).
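A plausible mechanism, sketched with plain PyTorch under the assumption that load_state_dict() fills existing buffers in place with copy_() (its default behavior, to my understanding): an in-place write into an expanded view is rejected because several elements share one memory location, while the same write into a fully allocated tensor such as torch.zeros(2) succeeds. This would also explain why the torch.zeros(z_dim) construction above loads correctly.

```python
import torch

source = torch.ones(2)

# Destination like the Pixyz buffer: a scalar expanded to two elements (a view).
view_dest = torch.tensor(0.).expand(2)
try:
    view_dest.copy_(source)  # in-place write, as load_state_dict() does by default
    view_copy_ok = True
except RuntimeError as err:
    view_copy_ok = False
    print("copy_ into the expanded view failed:", err)

# Destination like torch.zeros(z_dim): every element has its own memory.
full_dest = torch.zeros(2)
full_dest.copy_(source)
print(full_dest)  # tensor([1., 1.])
```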

Therefore, I added a clone() call after expanding to the features' shape. Although it uses a little more memory, it works correctly.
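The effect of the clone() fix can be sketched in plain PyTorch with a hypothetical ScalarLoc module (not Pixyz code). A checkpoint saved from the view-based version loads without error into the cloned version, which suggests that only the destination buffer has to own its memory; reading from an expanded source is fine.

```python
import io

import torch
from torch import nn


class ScalarLoc(nn.Module):
    """Hypothetical module storing a scalar `loc` broadcast to shape (1, 2)."""

    def __init__(self, value: float, clone: bool):
        super().__init__()
        loc = torch.tensor(value).expand(1, 2)
        if clone:
            loc = loc.clone()  # the fix: give the buffer its own memory
        self.register_buffer("loc", loc)


# Save from the old, view-based version (like ./tmp.pt above).
checkpoint = io.BytesIO()
torch.save(ScalarLoc(3.0, clone=False).state_dict(), checkpoint)
checkpoint.seek(0)

# Load into the fixed version: the destination buffer is no longer a view.
fixed = ScalarLoc(0.0, clone=True)
result = fixed.load_state_dict(torch.load(checkpoint))
print(result)     # <All keys matched successfully>
print(fixed.loc)  # tensor([[3., 3.]])
```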

>>> q = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), features_shape=[z_dim])
>>> q.load_state_dict(torch.load("./tmp.pt"))
<All keys matched successfully>
>>> q = Normal(loc=torch.zeros(z_dim), scale=torch.ones(z_dim))
>>> q.load_state_dict(torch.load("./tmp.pt"))
<All keys matched successfully>

I would be glad if this pull request helps.

Thank you.

ktaaaki commented 4 years ago

Thank you not only for the bug report but also for the easy-to-read pull request! After reading your comment, I found that the fix can be generalized.

If you pass an already expanded tensor as a parameter, as in Normal(loc=torch.tensor(0.).expand(1,2), scale=torch.ones(1,2)), the error still occurs. How about making the tensor passed to torch.nn.Module.register_buffer contiguous, as follows?

    def _check_features_shape(self, features):
        # scalar
        if features.size() == torch.Size():
            features = features.expand(self.features_shape)

        if self.features_shape == torch.Size():
            self._features_shape = features.shape

        # for the issue of torch.load (#110)
        if not features.is_contiguous():
            features = features.contiguous()

        if features.size() == self.features_shape:
            batches = features.unsqueeze(0)
            return batches

        raise ValueError("the shape of a given parameter {} and features_shape {} "
                         "do not match.".format(features.size(), self.features_shape))
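A plain-PyTorch sketch of why contiguous() generalizes the clone() fix: it copies any non-contiguous tensor, whether the expansion happened inside _check_features_shape() or in the caller's code, and returns an already contiguous tensor unchanged, so inputs like torch.zeros(z_dim) cost no extra memory. The dictionary labels are illustrative only.

```python
import torch

candidates = {
    "scalar expanded by _check_features_shape()": torch.tensor(0.).expand(2),
    "tensor expanded by the caller": torch.tensor(0.).expand(1, 2),
    "already contiguous tensor": torch.zeros(2),
}

for name, features in candidates.items():
    result = features.contiguous()
    # A different data pointer means contiguous() allocated and filled a new tensor.
    copied = result.data_ptr() != features.data_ptr()
    print(f"{name}: contiguous before={features.is_contiguous()}, "
          f"copied={copied}, stride after={result.stride()}")
```

Since contiguous() already returns the input unchanged when it is contiguous, the is_contiguous() guard in the snippet above is optional, though harmless.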
rnagumo commented 4 years ago

Thank you for your reply. Your suggestion is more general, so I changed the code accordingly. I checked that all of the following constructions of the Normal distribution load the saved parameters.

>>> q = Normal(loc=torch.tensor(0.).expand(2), scale=torch.ones(2))
>>> q.load_state_dict(torch.load("./tmp.pt"))
<All keys matched successfully>
>>> q = Normal(loc=torch.tensor(0.), scale=torch.ones(2), features_shape=[2])
>>> q.load_state_dict(torch.load("./tmp.pt"))
<All keys matched successfully>
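This check can also be automated. The sketch below uses a hypothetical ToyNormal module (not Pixyz's Normal) that mimics the suggested _check_features_shape() logic in simplified form: expand a scalar, make the result contiguous, add a batch dimension, and register it as a buffer. It saves one checkpoint and loads it into each of the constructions discussed above.

```python
import io

import torch
from torch import nn


class ToyNormal(nn.Module):
    """Hypothetical stand-in for a Pixyz distribution; not Pixyz's actual class."""

    def __init__(self, loc, scale, features_shape=()):
        super().__init__()
        self.features_shape = torch.Size(features_shape)
        self.register_buffer("loc", self._check_features_shape(loc))
        self.register_buffer("scale", self._check_features_shape(scale))

    def _check_features_shape(self, features):
        # Simplified version of the logic suggested in this pull request.
        if features.size() == torch.Size():
            features = features.expand(self.features_shape)
        if self.features_shape == torch.Size():
            self.features_shape = features.shape
        features = features.contiguous()  # copies only if the tensor is a view like expand()
        if features.size() != self.features_shape:
            raise ValueError(f"shape {features.size()} does not match {self.features_shape}")
        return features.unsqueeze(0)  # add a batch dimension


# One checkpoint, saved from the construction that originally failed.
checkpoint = io.BytesIO()
torch.save(ToyNormal(torch.tensor(0.), torch.tensor(1.), features_shape=[2]).state_dict(),
           checkpoint)

constructions = {
    "scalars + features_shape": lambda: ToyNormal(torch.tensor(0.), torch.tensor(1.), [2]),
    "expanded loc": lambda: ToyNormal(torch.tensor(0.).expand(2), torch.ones(2)),
    "scalar loc + features_shape": lambda: ToyNormal(torch.tensor(0.), torch.ones(2), [2]),
    "full tensors": lambda: ToyNormal(torch.zeros(2), torch.ones(2)),
}

results = {}
for name, make in constructions.items():
    checkpoint.seek(0)
    results[name] = make().load_state_dict(torch.load(checkpoint))
    print(f"{name}: {results[name]}")
```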