JuanFMontesinos / VoViT

VoViT: Low Latency Graph-based Audio-Visual Voice Separation Transformer
https://ipcv.github.io/VoViT/

Training and Testing code #3

Open CarolineCheng233 opened 2 years ago

CarolineCheng233 commented 2 years ago

Can you please provide the code for the training and testing process? I also wonder about the detailed training settings, such as the batch size. Thanks a lot!

JuanFMontesinos commented 2 years ago

Hi, I'm afraid we don't support that. The dataloader we use is somewhat complex: it assigns IDs dynamically to each file and has a wider set of options for different experiments we didn't publish. We don't provide the training code because it would be costly for us to help everyone address the resulting issues.

If you just want to train the model using the face landmarks and the audio, it's fairly simple to create a dataloader. You just need 2 audios plus the landmarks corresponding to the target face. The audios are normalized w.r.t. their absolute maximum.
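
For illustration, a minimal dataloader along those lines might look like the sketch below. The file layout, landmark format, the 4-second crop, and the assumption that audio is already at 16384 Hz are mine, not the original pipeline.

    import numpy as np
    import torch
    import torchaudio
    from torch.utils.data import Dataset


    class TwoSpeakerDataset(Dataset):
        """Hypothetical sketch: one sample = two waveforms + the target speaker's landmarks."""

        def __init__(self, samples, duration=4.0, sr=16384):
            # samples: list of (audio_path_a, audio_path_b, landmarks_path_a) tuples (assumed layout)
            self.samples = samples
            self.n_samples = int(duration * sr)

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            path_a, path_b, lnd_path = self.samples[idx]
            wav_a, _ = torchaudio.load(path_a)
            wav_b, _ = torchaudio.load(path_b)
            wav_a = wav_a[0, :self.n_samples]
            wav_b = wav_b[0, :self.n_samples]
            # Normalize each waveform w.r.t. its absolute maximum
            wav_a = wav_a / wav_a.abs().max()
            wav_b = wav_b / wav_b.abs().max()
            landmarks = torch.from_numpy(np.load(lnd_path)).float()  # target-face landmarks
            return wav_a, wav_b, landmarks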

The core network is the one called AudioVisualNetwork. In the forward you can find a line similar to:

 audio_feats = self.audio_processor.preprocess_audio(s1,s2)

This will create the mixtures for you. Then you just need to create the ground-truth masks (which are bounded for stability) with the code below; it assumes torch, einops.rearrange and torch.nn.functional's mse_loss/l1_loss are imported. Note that the spectrograms have shape batch, freq, time, [real, imaginary].

  def get_loss_mask(self, logits_mask, sp_mix, sources):
      spm = rearrange(sources[0], 'b c f t -> b f t c')
      sp_mix = rearrange(sp_mix, 'b c f t -> b f t c')
      # Masks are typically bounded when used as loss
      # Two types are computed:
      # loss_mask is the one which would be used in a loss function
      # gt_mask (the inference mask) is the one that, multiplied by the mixture, gives the independent source

      if self.complex_enabled:
          # loss_mask = torch.permute(self.tanh(logits_mask), [0,2,3,1])
          loss_mask = rearrange(self.tanh(logits_mask), 'b c f t -> b f t c')
          gt_mask = self.complex_mask(spm, sp_mix)
      else:
          raise NotImplementedError(f'Not tested')
      return loss_mask, gt_mask

  def compute_loss(self, pred, gt, weight, loss_criterion='mse'):

      if self.complex_enabled or not self.loss_on_mask:
          # Complex mask and gt shape BxFxTxC, weight unsqueezing required for broadcasting
          # The same applied for direct estimation, in which the mask multiplies the mixture as real + imag
          # However ratio masks are applied over the magnitude so no broadcasting is used
          weight = weight.unsqueeze(-1)
          pred = torch.view_as_real(pred) if not self.loss_on_mask else pred
      assert pred.shape == gt.shape, 'Mask computation: ground truth and predictions have to be the same shape'
      if loss_criterion.lower() == 'mse':
          if self.weighted_loss:
              loss = (weight * (pred - gt).pow(2)).mean()
          else:
              loss = mse_loss(pred, gt)
      elif loss_criterion.lower() == 'l1':
          if self.weighted_loss:
              loss = (weight * (pred - gt).abs()).mean()
          else:
              loss = l1_loss(pred, gt)
      else:
          raise ValueError(f'Unsupported loss criterion: {loss_criterion}')
      return loss

  @staticmethod
  def tanh(x):
      K = 10
      # Bounded compression of the mask: K * (1 - exp(-C * x)) / (1 + exp(-C * x)),
      # implemented with torch.tanh to deal with the asymptotic values.
      # Reference implementation: https://github.com/vitrioil/Speech-Separation/blob/master/src/models/complex_mask_utils.py
      return K * torch.tanh(x)

  @torch.no_grad()
  def complex_mask(self, sp0, sp_mix, eps=torch.finfo(torch.float32).eps):
      # Bibliography about complex masks
      # http://homes.sice.indiana.edu/williads/publication_files/williamsonetal.cRM.2016.pdf
      assert sp0.shape == sp_mix.shape

      sp_mix = sp_mix + eps
      mask = complex_division(sp0, sp_mix) / self.n_sources
      mask_bounded = self.tanh(mask)
      return mask_bounded
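
To see how these pieces fit together at training time, here is a hedged sketch of a training step; the model's forward signature, the audio_processor attribute wiring, and the assumption that loss_on_mask and complex_enabled are on are mine, not the repo's actual code.

    # Hypothetical glue code around the methods shown above
    def training_step(model, s1, s2, landmarks):
        feats = model.audio_processor.preprocess_audio(s1, s2)
        logits_mask = model(feats['mixture'], landmarks)  # network forward (signature assumed)
        loss_mask, gt_mask = model.get_loss_mask(logits_mask,
                                                 feats['mixture'],
                                                 feats['sources'])
        loss = model.compute_loss(loss_mask, gt_mask, feats['weight'])
        return loss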

I can give you parts of the code, but I don't think it's going to be clearer than this one (which has been polished).

The refinement network is trained once the main stage is trained, as described in the paper. We used a batch size of 10.

Best, Juan

JuanFMontesinos commented 2 years ago

The weight is:

weight = torch.log1p(magnitude_sp_mix)
weight = torch.clamp(weight, 1e-3, 10)

Sreyan88 commented 2 years ago

Hi @JuanFMontesinos ,

Thank you for the code. Great paper! I am indeed trying to replicate the training.

I had a few questions and would be grateful if you could please answer them:

1) audio_feats = self.audio_processor.preprocess_audio(s1,s2) --> where exactly in this function is the mixture of 2 audios being created? I just see the first file being read here (src[0]): https://github.com/JuanFMontesinos/VoViT/blob/499a689860d832edc3b5c744ef5d62eeddc96d3b/vovit/core/models/production_model.py#L121-L139

Do you have a different preprocess_audio while training? or am I missing something here?

2) How do you create the extra dimensions here, i.e. the [real, imaginary] part of batch, freq, time, [real, imaginary]?

Thank You!

JuanFMontesinos commented 2 years ago

Hi @Sreyan88, the code over there is just for inference (it's a simplified function). At training time, we create the mixtures by mixing 2 or more audios (depending on the setup). The preprocess_audio function used for training looks like this:

    def preprocess_audio(self, *src: list, real_sample=False, n_sources=2):
        """
        Inputs contain the following keys:
           audio: the main audio waveform of shape N,M
           audio_acmt: the secondary audio waveform of shape N,M
           src: if running inference on real mixtures, the mixture audio waveform of shape N,M
        """

        self.n_sources = n_sources if real_sample else len(src)
        self.n_sources = self.n_sources + 1 if self.remix_input else self.n_sources

        with torch.no_grad():
            # Inference for a real mixture
            if real_sample:
                self.n_sources = n_sources
                # Inference in case of a real sample
                sp_mix_raw = self.wav2sp(src[0]).contiguous() / self.n_sources

                if self.downsample_coarse:
                    sp_mix = sp_mix_raw[:, ::2, ...]  # BxFxTx2
                elif self.downsample_interp:
                    raise NotImplementedError
                else:
                    sp_mix = sp_mix_raw
            else:

                sp = [self.wav2sp(x) for x in src]  # Spectrogram shape BxFxTx2
                assert len(sp) > 0, 'Empty list of audio inputs'

                # This is the main spectrogram, which is the target if there is a single output modality
                # Used when remix_input is enabled
                spm = sp[0]

                if self.remix_input:
                    B = spm.shape[0]  # Batch elements
                    ndim = spm.ndim
                    coef = (torch.rand(B, *[1 for _ in range(ndim - 1)], device=spm.device) < self.remix_coef).byte()
                    # sources = sp + [spm.flip(0) * coef]
                    sources = sp + [torch.roll(spm, shifts=1, dims=0) * coef]
                else:
                    sources = sp
                sp_mix_raw = (sum(sources) / self.n_sources).contiguous()
                sources_raw = sources
                # Downsampling to save memory
                if self.downsample_coarse:
                    # Shape becomes   B x F/2 x T x 2
                    sources = [x[:, ::2, ...] for x in sources_raw]
                    sp_mix = sp_mix_raw[:, ::2, ...]
                elif self.downsample_interp:
                    raise NotImplementedError
                else:
                    sp_mix = sp_mix_raw

                mag = sp_mix.norm(dim=-1)  # Magnitude spectrogram BxFxT

                if self.weighted_loss:
                    weight = torch.log1p(mag)
                    weight = torch.clamp(weight, 1e-3, 10)
                else:
                    weight = None
            if self.complex_enabled:
                x = rearrange(sp_mix, 'b f t c -> b c f t')
                # x = torch.permute(sp_mix, [0,3,1,2])
                if not real_sample:
                    sources = [rearrange(x, 'b f t c -> b c f t') for x in sources]
            elif self.log_sp_enabled:
                epsilon = 1e-4
                x = (mag + epsilon).log().unsqueeze(1)
                if not real_sample:
                    sources = [(x.norm(dim=-1) + epsilon).log().unsqueeze(1) for x in sources]

            elif self.mel_enabled:
                x = self.sp2mel(mag).unsqueeze(1)
                sources = [x.norm(dim=-1).unsqueeze(1) for x in sources]
                raise NotImplementedError('Option not implemented in depth. Draft written.')
            else:
                x = mag.unsqueeze(1)
                if not real_sample:
                    sources = [x.norm(dim=-1).unsqueeze(1) for x in sources]
        output = {'mixture': x, 'sp_mix_raw': sp_mix_raw}
        if not real_sample:
            output.update({'src_sp_raw': sources_raw, 'sources': sources, 'weight': weight})
        return output
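
As a rough usage note, calling it with two batched waveforms returns something like the following; the shapes assume downsample_coarse and complex_enabled, and audio_processor is just a placeholder name for the module holding this method.

    # s1, s2: waveform batches of shape B x M
    out = audio_processor.preprocess_audio(s1, s2)
    out['mixture']     # mixture spectrogram, B x 2 x F/2 x T, channels = [real, imag]
    out['sources']     # list of per-source spectrograms, each B x 2 x F/2 x T
    out['sp_mix_raw']  # full-resolution mixture spectrogram, B x F x T x 2
    out['weight']      # log1p-magnitude loss weight, B x F/2 x T (None if weighted_loss is off)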

Within preprocess_audio, the mix is created here:

                if self.remix_input:
                    B = spm.shape[0]  # Batch elements
                    ndim = spm.ndim
                    coef = (torch.rand(B, *[1 for _ in range(ndim - 1)], device=spm.device) < self.remix_coef).byte()
                    # sources = sp + [spm.flip(0) * coef]
                    sources = sp + [torch.roll(spm, shifts=1, dims=0) * coef]

We basically compute a binary coefficient from random values and roll (permute) the batch, so each sample may get an extra source taken from another element of the batch.
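
As a toy illustration of that roll-plus-coefficient trick (made-up shapes, purely for intuition):

    import torch

    spm = torch.arange(4.).reshape(4, 1, 1, 1)    # pretend batch of 4 spectrograms: [0, 1, 2, 3]
    rolled = torch.roll(spm, shifts=1, dims=0)    # batch order becomes [3, 0, 1, 2]
    coef = (torch.rand(4, 1, 1, 1) < 0.5).byte()  # per-sample on/off switch
    extra = rolled * coef                         # some samples get an extra voice from another batch element, others get zeros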

  2. The dimensions are created by PyTorch. At the time of releasing the code, PyTorch didn't support complex numbers: stft returned a 2-channel tensor whose last dimension corresponds to the real and imaginary parts. This behaviour is controlled (and may already be deprecated in the newest versions) by the flag return_complex=False in the stft operator.
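
For reference, in recent PyTorch versions the two layouts relate roughly like this (illustrative parameters, not the ones from the paper):

    import torch

    wav = torch.randn(1, 16000)
    window = torch.hann_window(512)
    sp = torch.stft(wav, n_fft=512, hop_length=128, window=window,
                    return_complex=True)   # B x F x T, complex-valued
    sp_2ch = torch.view_as_real(sp)        # B x F x T x 2, last dim = [real, imag]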

Sreyan88 commented 2 years ago

Hi @JuanFMontesinos ,

Thank you so much for this! I will try to keep this discussion alive until I am able to replicate the training!

JuanFMontesinos commented 2 years ago

Sure @Sreyan88. Note that it takes ~20 days to train on a 3090. At the beginning it may take a while to start converging if you use only 2 voices, as it's a hard problem.

Sreyan88 commented 1 year ago

Hi @JuanFMontesinos ,

I have two questions and would be grateful if you could answer them!

(1) Can you please provide me with the code for complex_division()?
(2) I want to generate complex masks for speech dereverberation. I wanted to know what n_sources in mask = complex_division(sp0, sp_mix) / self.n_sources should be.

JuanFMontesinos commented 1 year ago

Hi @Sreyan88, complex_division is defined here:
https://github.com/JuanFMontesinos/VoViT/blob/dfd11c2ec5faa09f51cf68ae7ab42e5794941468/vovit/core/models/production_model.py#L68-L73
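
For readers without the repo at hand, a complex division over tensors whose last dimension is [real, imag] boils down to the usual formula; here is a sketch, not necessarily identical to the linked implementation:

    import torch

    def complex_division(x, y):
        # (a + bi) / (c + di) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
        a, b = x[..., 0], x[..., 1]
        c, d = y[..., 0], y[..., 1]
        denom = c ** 2 + d ** 2
        real = (a * c + b * d) / denom
        imag = (b * c - a * d) / denom
        return torch.stack([real, imag], dim=-1)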

Regarding the masks, masks are usually defined as $T(k,l)=M(k,l)\odot S(k,l)$, where $\odot$ denotes the complex product, $T$ is the target source and $S$ is the mixture. If we now assume we have two sources, $S=(T_1+T_2)/2$. Note that sources can be combined in different ways; taking the average is just a widely used one. Some papers take the maximum or apply a weighted average.

$M=T_1/S=2\,\frac{T_1}{T_1+T_2}$

In our case we divide by n_sources to remove that scaling factor, so that our masks are defined as $M=\frac{T_1}{T_1+T_2}$.

The interpretation is the following: if you do not divide by the scaling factor, you are making the network estimate how many sources there are and compensate to recover the original loudness. If you divide by the scaling factor, the network just isolates the source at its current loudness within the mixture.

Since we know the number of sources, we just proceeded as in the second case. However, it's a very minor detail; you shouldn't worry about it.
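
A tiny numerical sketch of the two conventions, using PyTorch complex scalars just for illustration:

    import torch

    t1 = torch.tensor(4 + 2j)   # target source (one TF bin)
    t2 = torch.tensor(1 - 1j)   # interfering source
    s = (t1 + t2) / 2           # mixture, averaged over 2 sources
    m_rescaled = t1 / s         # = 2 T1 / (T1 + T2); m_rescaled * s == t1 (original loudness)
    m_ours = t1 / s / 2         # = T1 / (T1 + T2);  m_ours * s == t1 / 2 (loudness inside the mixture)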

Sreyan88 commented 1 year ago

Thank you so much for your reply! Can you also please share your STFT parameters?

JuanFMontesinos commented 1 year ago

As stated in Sec. 3.2 of the paper, audio is resampled to 16384 Hz and the STFT is computed with a hop size of 256 and an FFT size of 1022.
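
In code, those settings would look roughly like this (the Hann window is an assumption; check the repo's audio transforms for the exact configuration):

    import torch
    import torchaudio

    wav, sr = torchaudio.load('speech.wav')               # hypothetical input file
    wav = torchaudio.functional.resample(wav, sr, 16384)  # resample to 16384 Hz
    sp = torch.stft(wav, n_fft=1022, hop_length=256,
                    window=torch.hann_window(1022), return_complex=True)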