Rayhane-mamah / Efficient-VDVAE

Official Pytorch and JAX implementation of "Efficient-VDVAE: Less is more"
https://arxiv.org/abs/2203.13751
MIT License

Applying DiscMixLogistic to non-image data #14

Open laplaceon opened 1 year ago

laplaceon commented 1 year ago

This isn't strictly about VDVAEs but rather about the DiscMixLogistic loss and applying it to non-image data, specifically multi-channel sequences such as stereo audio.

I modified all the tensor slicing operations to use one fewer dimension (B, C, L instead of B, C, H, W), set num_classes to 65535, and set min_pix_value and max_pix_value to -1 and 1 respectively. I also removed the autoregressive conditioning between the RGB channels (the RGB AR coefficient code).

import numpy as np
import torch
import torch.nn.functional as F

# _compute_inv_stdv is the helper from the original Efficient-VDVAE loss code;
# it returns (inv_stdv, log_scales) for the given scale logits.


def discretized_mix_logistic_loss(logits, targets, num_classes=65535, num_mixtures=8):
    # Shapes:
    #    targets: B, C, L  (values in [-1, 1])
    #    logits: B, M * (3 * C + 1), L

    assert len(targets.shape) == 3
    B, C, L = targets.size()

    min_pix_value, max_pix_value = -1, 1

    targets = targets.unsqueeze(2)  # B, C, 1, L

    logit_probs = logits[:, :num_mixtures, :]  # B, M, L (mixture weight logits)
    l = logits[:, num_mixtures:, :]  # B, M * C * 3, L
    l = l.reshape(B, C, 3 * num_mixtures, L)  # B, C, 3 * M, L

    model_means = l[:, :, :num_mixtures, :]  # B, C, M, L

    inv_stdv, log_scales = _compute_inv_stdv(
        l[:, :, num_mixtures: 2 * num_mixtures, :], distribution_base='logstd')

    # The third block, l[:, :, 2 * M: 3 * M, :], held the RGB autoregressive
    # coefficients; it is left unused now that channel conditioning is removed.

    centered = targets - model_means  # B, C, M, L

    # The num_classes + 1 levels on [-1, 1] are 2 / num_classes apart, so each
    # bin extends 1 / num_classes on either side of its center.
    plus_in = inv_stdv * (centered + 1. / num_classes)
    cdf_plus = torch.sigmoid(plus_in)
    min_in = inv_stdv * (centered - 1. / num_classes)
    cdf_min = torch.sigmoid(min_in)

    log_cdf_plus = plus_in - F.softplus(plus_in)  # log probability for the lowest level (-1)
    log_one_minus_cdf_min = -F.softplus(min_in)  # log probability for the highest level (+1)

    # probability for all other cases
    cdf_delta = cdf_plus - cdf_min  # B, C, M, L

    mid_in = inv_stdv * centered
    # log density at the bin center; below, log_pdf_mid - log(num_classes / 2)
    # (log pdf + log bin width) replaces log(cdf_delta) when cdf_delta is tiny
    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)

    # Edge levels are matched by exact equality, so targets must land exactly
    # on -1 and 1 for the edge branches to fire.
    broadcast_targets = torch.broadcast_to(targets, size=[B, C, num_mixtures, L])
    log_probs = torch.where(broadcast_targets == min_pix_value, log_cdf_plus,
                            torch.where(broadcast_targets == max_pix_value, log_one_minus_cdf_min,
                                        torch.where(cdf_delta > 1e-5,
                                                    torch.log(torch.clamp(cdf_delta, min=1e-12)),
                                                    log_pdf_mid - np.log(num_classes / 2))))  # B, C, M, L

    # Without AR coefficients, channels are independent given the mixture
    # component: sum over C, then add the log mixture weights.
    log_probs = torch.sum(log_probs, dim=1) + F.log_softmax(logit_probs, dim=1)  # B, M, L
    negative_log_probs = -torch.logsumexp(log_probs, dim=1)  # B, L (nats per timestep, summed over C)

    return negative_log_probs
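The edge branches in the loss compare targets with -1 and 1 exactly and use a half-bin width of 1 / num_classes, which implicitly assumes the num_classes + 1 = 65536 levels sit at -1 + 2k / 65535. The sketch below is a hypothetical pair of helpers (`int16_to_unit_range` and `unit_range_to_int16`, not part of Efficient-VDVAE), assuming the audio starts out as int16 PCM. For comparison, the common mapping pcm / 32768 tops out at 32767 / 32768 ≈ 0.99997, so the +1 branch would never fire and the level spacing (2 / 65536) would differ slightly from the bin width the loss assumes.

```python
import numpy as np

NUM_CLASSES = 65535  # matches num_classes in the loss (2**16 - 1)


def int16_to_unit_range(pcm):
    """Hypothetical helper: map int16 PCM onto the levels the loss expects.

    Level k in [0, 65535] lands on -1 + 2k / 65535, so -32768 -> -1.0 and
    32767 -> 1.0 exactly, with neighbouring levels 2 / NUM_CLASSES apart.
    """
    k = pcm.astype(np.float64) + 32768.0
    # k / NUM_CLASSES is exactly 0.0 and 1.0 at the extremes
    return (k / NUM_CLASSES * 2.0 - 1.0).astype(np.float32)


def unit_range_to_int16(x):
    """Hypothetical inverse: snap values in [-1, 1] back to int16 PCM."""
    k = np.rint((x.astype(np.float64) + 1.0) / 2.0 * NUM_CLASSES)
    return (k - 32768).astype(np.int16)
```

Only the mapping into [-1, 1] matters for the loss; the inverse is there to check that every one of the 65536 levels survives a float32 round trip.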

During training, the loss barely improves from epoch to epoch and stays very high (over 100). I'm wondering whether any of my changes are causing this, or whether additional changes are needed. Also, the sequences are long (at least 8192 samples), so maybe this is expected behavior?
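One way to judge whether a value like "over 100" is high is to put the NLL on a per-sample scale. A minimal sketch, assuming the (B, L) output of the loss is in nats and is averaged over batch and time (a sum over the 8192+ timesteps would scale the raw number by L). The helper names are hypothetical. A model that is uniform over all 65536 levels scores 16 bits (about 11.09 nats) per channel per timestep, so for stereo the uniform baseline is already about 22.2 nats per timestep.

```python
import math

import numpy as np


def nats_to_bits_per_dim(negative_log_probs, num_channels):
    """Average a (B, L) per-timestep NLL in nats into bits per scalar sample.

    Each entry already sums over the C channels, so the mean over batch and
    time is divided by C and converted from nats to bits.
    """
    nll = np.asarray(negative_log_probs, dtype=np.float64)
    return float(nll.mean() / num_channels / math.log(2.0))


def uniform_baseline_bits(num_classes=65535):
    # A model uniform over all num_classes + 1 levels scores log2(levels) bits.
    return math.log2(num_classes + 1)
```

Comparing the bits-per-sample figure against the 16-bit uniform baseline shows how much the model has learned, independent of sequence length and channel count.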

Rayhane-mamah commented 1 year ago

Hi @laplaceon, thanks for reaching out with your question.

This brings back some memories from when I worked on text-to-speech :) Let me compile a couple of resources, write up a proper answer, and get back to you.

Sorry for the delay!

laplaceon commented 1 year ago

Hello Rayhane. I thought more about my question and realized there is a contradiction in my loss function. Before adding DiscMixLogistic, my loss was the L1/L2 distance between the spectrograms of my output and target waveforms. I read a comment you made in another issue about how training with an L2 loss causes over-regularization, so I added a DiscMixLogistic loss between the output and target waveforms. What I should have done instead is replace the L2 loss on the spectrograms with a DiscMixLogistic loss on the spectrograms. That way I avoid the over-regularization and work with tensors that have the same dimensionality as images.

There is one problem I've been stuck on, though: how to translate the mean, variance, and mixture-weight logits from the waveform domain to the spectrogram domain so that backpropagation can still be applied. The means seem simple to translate, since they are just the MelSpectrogram transformation applied to the waveform means, or so I think. I'm not sure how the variances and mixture weights can be translated into the spectrogram domain, however.

But if I do figure it out, I can use the DiscMixLogistic loss basically unchanged. It would just be calculated on "images" with C != 3.
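If the model emits mixture parameters directly for every time-frequency bin of a spectrogram scaled to [-1, 1] (an assumption; how to derive them from waveform-domain parameters is still the open question), the existing (B, C, L) loss can be reused by flattening frequency and time into one axis. A minimal sketch with a hypothetical helper, `flatten_time_frequency`; it only uses `.shape` and `.reshape`, so it works the same on torch tensors and NumPy arrays.

```python
import numpy as np


def flatten_time_frequency(logits, targets):
    """Hypothetical helper: flatten (F, T) maps into one L = F * T axis.

    targets: B, C, F, T             -> B, C, F * T
    logits:  B, M * (3C + 1), F, T  -> B, M * (3C + 1), F * T
    The channel axis is untouched, so the loss's channel slicing still holds.
    """
    batch, channels, n_freq, n_frames = targets.shape
    assert logits.shape[0] == batch
    assert tuple(logits.shape[2:]) == (n_freq, n_frames)
    flat_logits = logits.reshape(batch, logits.shape[1], n_freq * n_frames)
    flat_targets = targets.reshape(batch, channels, n_freq * n_frames)
    return flat_logits, flat_targets


# Shape check with NumPy stand-ins (B=2, C=3, M=4, 5 frequency bins, 6 frames).
rng = np.random.default_rng(0)
targets = rng.uniform(-1, 1, size=(2, 3, 5, 6))
logits = rng.normal(size=(2, 4 * (3 * 3 + 1), 5, 6))
flat_logits, flat_targets = flatten_time_frequency(logits, targets)
```

Because the loss scores every position independently, flattening does not change the summed NLL, and the resulting (B, L) output can be reshaped back to (B, F, T) to inspect per-bin losses.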