microsoft / UniSpeech

UniSpeech - Large Scale Self-Supervised Learning for Speech

Bug in WavLM? #24

Closed OlaWod closed 2 years ago

OlaWod commented 2 years ago

https://github.com/microsoft/UniSpeech/blob/e3043e2021d49429a406be09b9b8432febcdec73/WavLM/WavLM.py#L320

I think this should be `padding_mask = padding_mask.any(-1)`.


The argument is as follows:

Suppose I have a padded input of size (4, 90799), which consists of 4 wavs. Their lengths are 90799, 75108, 60146, 60087, respectively.

# y: (4, 90799) padded batch; y_mask: True where a sample is padding
feature, mask = wavlm.extract_features(y, padding_mask=y_mask)
print((1 - mask.int()).sum(-1))  # number of unmasked frames per wav

Running the above code prints [283, 235, 188, 188], but [283, 234, 187, 187] is expected, since 90799 // 320 = 283, 75108 // 320 = 234, 60146 // 320 = 187, and 60087 // 320 = 187. A frame that contains even one padded sample is only partially valid, so it should be masked; `.all(-1)` leaves it unmasked, which inflates each count (except the longest wav's) by one.
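The off-by-one can be reproduced without a model, using plain Python lists in place of tensors (a toy sketch; the frame size 320 and the lengths are taken from the example above):

```python
# Toy reproduction of the .all(-1) vs .any(-1) difference, mirroring
# forward_padding_mask with plain Python lists instead of torch tensors.
FRAME = 320          # WavLM's overall feature stride
TOTAL = 90799        # padded batch length
LENGTH = 75108       # true length of the second wav

# padding_mask[t] is True where sample t is padding.
padding_mask = [t >= LENGTH for t in range(TOTAL)]

# Trim the remainder, as forward_padding_mask does.
extra = TOTAL % FRAME
if extra:
    padding_mask = padding_mask[:-extra]

# "Reshape" to (num_frames, FRAME).
frames = [padding_mask[i:i + FRAME] for i in range(0, len(padding_mask), FRAME)]

mask_all = [all(f) for f in frames]  # current code: masked only if fully padded
mask_any = [any(f) for f in frames]  # proposed fix: masked if any sample is padded

valid_all = sum(not m for m in mask_all)
valid_any = sum(not m for m in mask_any)
print(valid_all, valid_any)  # 235 234
```

Frame 234 (0-based) covers samples 74880..75199, so it contains padding from sample 75108 on; `.all` treats it as valid, `.any` masks it, hence 235 vs 234.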

    def forward_padding_mask(
            self, features: torch.Tensor, padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        # padding_mask.size() = (4, 90799)
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        # padding_mask.size() = (4, 90560)
        padding_mask = padding_mask.view(
            padding_mask.size(0), features.size(1), -1
        )
        # padding_mask.size() = (4, 283, 320)
        # padding_mask[1] =
        # [[0, 0, 0, 0, ..., 0, 0, 0],
        #  [0, 0, 0, 0, ..., 0, 0, 0],
        #  ...
        #  [0, 0, 0, 0, ..., 1, 1, 1],
        #  ...
        #  [1, 1, 1, 1, ..., 1, 1, 1],
        #  [1, 1, 1, 1, ..., 1, 1, 1]]
        padding_mask = padding_mask.all(-1)
        # padding_mask.size() = (4, 283)
        # padding_mask[1] =
        # [False,
        #  False,
        #  ...
        #  False,  # this should be 'True'
        #  ...
        #  True,
        #  True]
        return padding_mask

Please correct me if I am wrong.

OlaWod commented 2 years ago

It seems that the extracted feature length is not strictly input wav length // 320...