NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch

Support torch.distributions for fp16 operations #812

Open jramapuram opened 4 years ago

jramapuram commented 4 years ago

I've been trying to get a VAE to work with apex.

I ran into a few issues along the way and tried to pull in lessons learned from https://github.com/chainer/chainer/issues/6168 (e.g. not taking exp of the logvar, as in that example); however, something that still seems to be missing is proper wrapping of torch.distributions, including operations such as D.kl_divergence.

As an example, here is the Normal : Normal KLD, which I can't seem to get to stop blowing up.
The only thing I see that could tentatively blow up here is the .pow(2). (Note that the division by .scale is not an issue, because people typically add some tolerance to the scale, e.g. scale += 1e-6, which I have tried truncating to 3 bits for fp16, but to no avail.)

@register_kl(Normal, Normal)
def _kl_normal_normal(p, q):
    var_ratio = (p.scale / q.scale).pow(2)
    t1 = ((p.loc - q.loc) / q.scale).pow(2)
    return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())
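
For a sense of scale (my illustration, not from the original report): float16 tops out at 65504, so the squared ratio term can overflow as soon as p.scale / q.scale exceeds roughly 256. One workaround is to upcast just the KL computation to fp32 and cast the result back; the sketch below does that, and the function name and arguments are placeholders of mine.

    import torch

    def kl_normal_normal_fp32(p_loc, p_scale, q_loc, q_scale):
        # Upcast to fp32 so the .pow(2) terms cannot overflow the fp16 range,
        # then cast the final KL back to the input dtype.
        out_dtype = p_loc.dtype
        p_loc, p_scale = p_loc.float(), p_scale.float()
        q_loc, q_scale = q_loc.float(), q_scale.float()
        var_ratio = (p_scale / q_scale).pow(2)
        t1 = ((p_loc - q_loc) / q_scale).pow(2)
        return (0.5 * (var_ratio + t1 - 1 - var_ratio.log())).to(out_dtype)

    # fp16 overflow demo: 300 ** 2 = 90000 exceeds the fp16 maximum of 65504.
    ratio = torch.tensor([300.0], dtype=torch.float16)
    print(ratio.pow(2))  # tensor([inf], dtype=torch.float16)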

I have gotten to the point where I can train a VAE without the KL term (i.e., basically an autoencoder with reparameterization but no regularization on the latent variable), but the KL is still an issue.

I have already tried O1 and O2; O2 fails due to an error with fp32 values being returned during reparameterization. Forcibly type-casting as follows does not work:

    def _reparametrize_gaussian(self, mu, logvar, force=False):
        """ Internal member to reparametrize gaussian.

        :param mu: mean logits
        :param logvar: log-variance.
        :returns: reparameterized tensor and param dict
        :rtype: torch.Tensor, dict

        """
        if self.training or force:  # returns a stochastic sample for training
            std = logvar.mul(0.5)  # .exp()
            eps = torch.zeros_like(logvar).normal_().type(std.dtype)
            nan_check_and_break(logvar, "logvar")
            reparam_sample = eps.mul(std).add_(mu)
            return reparam_sample, {'z': reparam_sample, 'mu': mu, 'logvar': logvar}

        return mu, {'z': mu, 'mu': mu, 'logvar': logvar}
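
For context, O1 and O2 above refer to the apex amp opt levels chosen at initialization. A minimal sketch of that setup is below; the model, optimizer, loss, and data loader names are placeholders, not the actual VAE code.

    import torch
    from apex import amp

    model = MyVAE().cuda()                                   # placeholder model
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # O1 patches torch functions to cast per-op (whitelisted ops run in fp16,
    # reductions stay in fp32); O2 casts the model weights themselves to fp16
    # while keeping fp32 master weights in the optimizer.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    for x in loader:                                         # placeholder loader
        optimizer.zero_grad()
        loss = model.loss(x)                                 # placeholder loss
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()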
jpatrickpark commented 4 years ago

@jramapuram It might be an issue of exploding KL divergence (a known problem even in full precision), which can be mitigated by adding an epsilon value. That seems to help me avoid overflow at the O2 optimization level.

For example: https://github.com/rwightman/pytorch-image-models/blob/3eb4a96edaecd272305fdf97fe50b22939153687/timm/loss/jsd.py#L35-L38
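
Something along these lines (a rough sketch of the epsilon-guard idea, not the linked code verbatim; the function name and the epsilon value are illustrative):

    import torch

    EPS = 1e-4  # illustrative; needs to be representable and non-negligible in fp16

    def kl_normal_normal_eps(p_loc, p_scale, q_loc, q_scale, eps=EPS):
        # Clamp the scales away from zero before dividing and taking the log,
        # so the variance ratio and its log stay finite in half precision.
        p_scale = p_scale.clamp(min=eps)
        q_scale = q_scale.clamp(min=eps)
        var_ratio = (p_scale / q_scale).pow(2)
        t1 = ((p_loc - q_loc) / q_scale).pow(2)
        return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())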

jramapuram commented 4 years ago

Good call @jpatrickpark, I'll give that a shot. Everything works fine for me in fp32 (even without gradient clipping or regularization such as spectral norm on the model).