Closed. OlaWod closed this issue 2 years ago.
https://github.com/microsoft/UniSpeech/blob/e3043e2021d49429a406be09b9b8432febcdec73/WavLM/WavLM.py#L320
I think this line should be `padding_mask = padding_mask.any(-1)` instead of `padding_mask = padding_mask.all(-1)`.
The argument is as follows:
Suppose I have a padded input of size (4, 90799), which consists of 4 wavs. Their lengths are 90799, 75108, 60146, 60087, respectively.
```python
feature, mask = wavlm.extract_features(y, padding_mask=y_mask)
print((1 - mask.int()).sum(-1))  # number of non-padded frames per wav
```
Running the code above prints `[283, 235, 188, 188]`, but I expected `[283, 234, 187, 187]`, because `90799 // 320 = 283`, `75108 // 320 = 234`, `60146 // 320 = 187`, and `60087 // 320 = 187`.
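Assuming, as in `forward_padding_mask`, that frame i covers the non-overlapping samples 320i through 320i + 319, and that every sample at index L or later is padding, `.all(-1)` counts frame i as non-padding whenever 320i < L, while `.any(-1)` does so only when 320(i + 1) ≤ L. Under that assumption the printed and the expected counts are a ceiling and a floor of the same ratio (T is the number of frames, 283 in this batch):

$$
N_{\text{all}}(L) = \min\!\left(\left\lceil \frac{L}{320} \right\rceil,\ T\right),
\qquad
N_{\text{any}}(L) = \min\!\left(\left\lfloor \frac{L}{320} \right\rfloor,\ T\right)
$$

$$
\text{e.g. } L = 75108:\quad \frac{75108}{320} = 234.7125,\qquad N_{\text{all}} = 235,\qquad N_{\text{any}} = 234
$$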
```python
import torch


def forward_padding_mask(
    self, features: torch.Tensor, padding_mask: torch.Tensor,
) -> torch.Tensor:
    # padding_mask.size() = (4, 90799)
    extra = padding_mask.size(1) % features.size(1)
    if extra > 0:
        padding_mask = padding_mask[:, :-extra]
    # padding_mask.size() = (4, 90560)
    padding_mask = padding_mask.view(
        padding_mask.size(0), features.size(1), -1
    )
    # padding_mask.size() = (4, 283, 320)
    # padding_mask[1] =
    # [[0, 0, 0, 0, ..., 0, 0, 0],
    #  [0, 0, 0, 0, ..., 0, 0, 0],
    #  ...
    #  [0, 0, 0, 0, ..., 1, 1, 1],
    #  ...
    #  [1, 1, 1, 1, ..., 1, 1, 1],
    #  [1, 1, 1, 1, ..., 1, 1, 1]]
    padding_mask = padding_mask.all(-1)
    # padding_mask.size() = (4, 283)
    # padding_mask[1] =
    # [False,
    #  False,
    #  ...
    #  False,  # this should be 'True'
    #  ...
    #  True,
    #  True]
    return padding_mask
```
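The walkthrough above can be reproduced without loading a model. The sketch below is a pure-Python stand-in for `forward_padding_mask` (it does not use torch or WavLM); it assumes padding is marked with `True`, a padded batch length of 90799 samples, and T = 283 frames as in the example. `padding_mask_from_lengths`, `frame_mask`, and `count_valid` are hypothetical helpers written for this illustration. Passing Python's built-in `all` or `any` as the per-frame reduction mirrors the current `.all(-1)` and the proposed `.any(-1)`.

```python
from typing import Callable, List

Mask = List[List[bool]]  # True marks padding, as in WavLM's padding_mask


def padding_mask_from_lengths(lengths: List[int], max_len: int) -> Mask:
    """Hypothetical helper: sample-level mask, True where t >= the wav's length."""
    return [[t >= n for t in range(max_len)] for n in lengths]


def frame_mask(sample_mask: Mask, num_frames: int,
               reduce: Callable[[List[bool]], bool]) -> Mask:
    """Mimic forward_padding_mask: drop the tail, split into equal chunks, reduce."""
    max_len = len(sample_mask[0])
    extra = max_len % num_frames             # padding_mask.size(1) % features.size(1)
    chunk = (max_len - extra) // num_frames  # 320 samples per frame in this example
    frames = []
    for row in sample_mask:
        row = row[: max_len - extra]         # mirrors the padding_mask[:, :-extra] truncation
        frames.append([reduce(row[i * chunk:(i + 1) * chunk]) for i in range(num_frames)])
    return frames


def count_valid(frames: Mask) -> List[int]:
    """Same count as (1 - mask.int()).sum(-1): frames not marked as padding."""
    return [sum(not f for f in row) for row in frames]


lengths = [90799, 75108, 60146, 60087]
num_frames = 283  # features.size(1) in the example above
mask = padding_mask_from_lengths(lengths, max(lengths))

print(count_valid(frame_mask(mask, num_frames, all)))  # current .all(-1): [283, 235, 188, 188]
print(count_valid(frame_mask(mask, num_frames, any)))  # proposed .any(-1): [283, 234, 187, 187]
```

With `all`, a frame that is only partly padding still counts as valid, which produces the extra frame for the three shorter wavs; with `any`, the counts match `L // 320` for all four wavs.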
Please correct me if I am wrong.
It seems that the extracted feature length is not strictly `input wav length // 320`...
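One way to check this is to compute the output length of the convolutional feature encoder directly. The sketch below assumes the wav2vec 2.0 / WavLM default `conv_feature_layers` of `[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2` with unpadded `Conv1d` layers; `CONV_LAYERS` and `conv_output_length` are hypothetical names. Under that assumption the encoder has a total stride of 320 but a receptive field of 400 samples, so a wav of L samples yields `(L - 400) // 320 + 1` frames rather than `L // 320`. The two agree for the four wavs in this issue, but differ by one whenever `L % 320 < 80` (the illustrative length 64000 below is one such case), so even `.any(-1)` would not always match the feature length.

```python
from typing import List, Tuple

# Assumed (kernel_size, stride) of each unpadded Conv1d layer, taken from the
# wav2vec 2.0 / WavLM default conv_feature_layers; verify against your checkpoint's cfg.
CONV_LAYERS: List[Tuple[int, int]] = [(10, 5)] + [(3, 2)] * 4 + [(2, 2)] * 2


def conv_output_length(num_samples: int) -> int:
    """Hypothetical helper: number of frames the conv encoder produces for one wav."""
    length = num_samples
    for kernel, stride in CONV_LAYERS:
        # Unpadded, undilated Conv1d: floor((L - kernel) / stride) + 1.
        # Clamped at 0 for inputs shorter than the receptive field (Conv1d would raise).
        length = max(0, (length - kernel) // stride + 1)
    return length


# Total stride is 5 * 2**6 = 320, but the receptive field is 400 samples.
for n in [90799, 75108, 60146, 60087, 64000]:
    print(n, n // 320, conv_output_length(n), (n - 400) // 320 + 1)
```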