acids-ircam / ddsp_pytorch

Implementation of Differentiable Digital Signal Processing (DDSP) in Pytorch
Apache License 2.0

Question in the multi-scale loss function #3

Closed wayne391 closed 4 years ago

wayne391 commented 4 years ago

Hi, thanks for your awesome implementation!

I am tracing the code and I have a question about the loss function. In the original paper, the loss is computed between the magnitudes of the STFTs.

However, in your code:

https://github.com/acids-ircam/ddsp_pytorch/blob/master/code/ddsp/loss.py#L21

# Lambda for computing squared amplitude
amp = lambda x: x[:,:,0]**2 + x[:,:,1]**2

Because the shape of the STFT tensor is (batch_size, n_fft // 2 + 1, n_frames, 2), it seems the output of the lambda function is the sum of the squares of the first two frames. I think the function should be rewritten like:

# Lambda for computing squared amplitude
amp = lambda x: x[:,:,:,0]**2 + x[:,:,:,1]**2

That way the return value is the sum of the squares of the real and imaginary parts. Am I correct?
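That reading can be checked with a minimal shape experiment on a dummy tensor of the (older, pre-`return_complex`) `torch.stft` output shape:

```python
import torch

# Dummy tensor with the old torch.stft output shape:
# (batch, n_fft // 2 + 1, n_frames, 2), last dim = real/imag
x = torch.randn(4, 1025, 100, 2)

wrong = x[:, :, 0]**2 + x[:, :, 1]**2        # indexes the *frame* axis: (4, 1025, 2)
right = x[:, :, :, 0]**2 + x[:, :, :, 1]**2  # indexes real/imag:       (4, 1025, 100)

print(wrong.shape)  # torch.Size([4, 1025, 2])
print(right.shape)  # torch.Size([4, 1025, 100])
```

The three-index version silently returns a tensor of the wrong shape rather than raising, which is why the bug is easy to miss.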


Here is my code for debugging:

# ddsp/loss.py
# class: MSSTFTLoss
# function: forward

def forward(self, x, stfts_orig):
        stfts = []
        print('\n\n== [Loss] ==\n')
        print('loss input:', x.shape)

        print('\n=============\n')
        for i, scale in enumerate(self.scales):   
            cur_fft = torch.stft(x, n_fft=scale, window=self.windows[i], hop_length=int((1-self.overlap)*scale), center=False)

            print('scale:', scale)
            print(' > output stft:', cur_fft.shape)
            print(' > output stft shape:', cur_fft[:,:,0].shape, cur_fft[:,:,1].shape)
            print(' > output stft (amp):', amp(cur_fft).shape)
            print('   ---   ')
            stfts.append(amp(cur_fft))

        print('\n=============\n')

        # Compute loss 
        # NOTE: the i loop must come before the j loop, otherwise
        # len(stfts[i]) is evaluated with an out-of-scope i
        lin_loss = sum([torch.mean(abs(stfts_orig[i][j] - stfts[i][j])) for i in range(len(stfts)) for j in range(len(stfts[i]))])
        log_loss = sum([torch.mean(abs(torch.log(stfts_orig[i][j] + 1e-4) - torch.log(stfts[i][j] + 1e-4))) for i in range(len(stfts)) for j in range(len(stfts[i]))])
        return lin_loss + log_loss

And a snapshot of the debug output: [image]

Nintorac commented 4 years ago

It is also possible to reference the final dimension like so:

# Lambda for computing squared amplitude
amp = lambda x: x[...,0]**2 + x[...,1]**2
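For what it's worth, the `...` form gives exactly the same result as the explicit four-index version whenever real/imag are stacked in the last dimension, and it keeps working if the tensor rank changes:

```python
import torch

# Ellipsis indexing always addresses the last dimension, regardless of rank
amp = lambda x: x[..., 0]**2 + x[..., 1]**2

x = torch.randn(4, 513, 100, 2)
explicit = x[:, :, :, 0]**2 + x[:, :, :, 1]**2
assert torch.equal(amp(x), explicit)
```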
caillonantoine commented 4 years ago

Hi, thank you for the issue, everything should work now!

linzwatt commented 4 years ago

Should this line be updated as well? https://github.com/acids-ircam/ddsp_pytorch/blob/master/code/ddsp/analysis.py#L10

This is used during dataset preprocessing to compute the STFTs at the various scales.

Also, during training of the autoencoder, why is x (the input) 32000 samples long, while x_tilde (the output) is 64000 samples long?

I found this while fixing the amp lambda in analysis.py: the sizes of the STFTs don't line up because the output is 2x longer than the input, and the preprocessed STFTs are computed from the input.

wayne391 commented 4 years ago

@linzwatt Hi, I remember it's because of the reverb module. If you remove the reverb, the input and output tensors will have the same shape (32000 samples, i.e. 2 seconds).
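The 2x length is consistent with a "full" convolution against an impulse response as long as the signal itself. A rough numpy sketch (the actual reverb module may pad differently, so the exact output length is an assumption here):

```python
import numpy as np

x = np.random.randn(32000)    # 2 s of audio at 16 kHz
ir = np.random.randn(32000)   # hypothetical reverb impulse response, same length
y = np.convolve(x, ir)        # "full" convolution keeps the whole reverb tail
print(y.shape)                # (63999,) — roughly twice the input length
```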

I am also confused about that line in "analysis.py". Waiting for the author's reply~

linzwatt commented 4 years ago

oh I see, that makes sense.

In the case of reverb on, it seems to me that reverb should be applied to the data during preprocessing, so that the multiscale STFT loss learns to reconstruct the trailing reverb, i.e. the decay after the end of the 32000-sample patch?

Yes, I think the line in analysis.py needs to be updated, or the function should be moved into a utils file, so that the loss and the preprocessing compute the same multiscale STFTs.
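A minimal sketch of such a shared helper (the module and function names are hypothetical; `return_complex=True` plus `torch.view_as_real` is used so it runs on current PyTorch, which no longer returns the trailing real/imag dimension by default):

```python
# hypothetical ddsp/utils.py: one multiscale-STFT helper shared by
# analysis.py (preprocessing) and loss.py (training loss)
import torch

def amp(stft):
    # Squared magnitude of an STFT tensor with real/imag in the last dim
    return stft[..., 0]**2 + stft[..., 1]**2

def multiscale_stft(x, scales, overlap=0.75):
    # x: (batch, n_samples); returns one squared-magnitude STFT per scale
    stfts = []
    for scale in scales:
        s = torch.stft(
            x,
            n_fft=scale,
            window=torch.hann_window(scale, device=x.device),
            hop_length=int((1 - overlap) * scale),
            center=False,
            return_complex=True,
        )
        stfts.append(amp(torch.view_as_real(s)))
    return stfts
```

If both files import from one place, fixing the indexing once fixes preprocessing and the loss together.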

andreser09 commented 4 years ago

Hi. Did you find a workaround for this last issue? Because I'm getting an error that I think is related:

Traceback (most recent call last):
  File "train.py", line 178, in <module>
    losses[i, 0] = model.train_epoch(train_loader, loss, optimizer, args)
  File "/home/code/model.py", line 78, in train_epoch
    rec_loss = loss(x_tilde, y) / float(x.shape[1] * x.shape[2])
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/code/ddsp/loss.py", line 61, in forward
    lin_loss = sum([torch.mean(abs(stfts_orig[i][j] - stfts[i][j])) for j in range(len(stfts[i])) for i in range(len(stfts))])
  File "/home/code/ddsp/loss.py", line 61, in <listcomp>
    lin_loss = sum([torch.mean(abs(stfts_orig[i][j] - stfts[i][j])) for j in range(len(stfts[i])) for i in range(len(stfts))])
RuntimeError: The size of tensor a (1997) must match the size of tensor b (3997) at non-singleton dimension 1

The dimensions of the tensors are the following:

stft: torch.Size([64, 33, 3997, 2])
stft_origs: torch.Size([64, 33, 1997, 2])
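Those numbers are consistent with the 2x-length issue discussed above: 33 frequency bins implies n_fft = 64, and with 75% overlap (hop = 16, assuming that is the scale/overlap in use here) the frame counts for 32000- and 64000-sample signals come out to exactly 1997 and 3997:

```python
def n_frames(length, n_fft, hop):
    # Frame count of torch.stft with center=False
    return 1 + (length - n_fft) // hop

print(n_frames(32000, 64, 16))  # 1997 — STFT frames of the 32000-sample input
print(n_frames(64000, 64, 16))  # 3997 — STFT frames of the 64000-sample output
```

So the error is the same input/output length mismatch, not a separate bug in the loss.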