pytorch / audio

Data manipulation and transformation for audio signal processing, powered by PyTorch
https://pytorch.org/audio
BSD 2-Clause "Simplified" License

Using MMS model with `star` token for batch size > 1 #3772

Open huangruizhe opened 2 months ago

huangruizhe commented 2 months ago

The current implementation assumes a batch size of one when appending the star dimension: https://github.com/pytorch/audio/blob/ea437b31ce316ea3d66fe73768c0dcb94edb79ad/src/torchaudio/pipelines/_wav2vec2/utils.py#L41

However, the underlying Wav2Vec2 model supports batch sizes greater than one, so this line should instead be:

`star_dim = torch.zeros((output.size(0), output.size(1), 1), dtype=output.dtype, device=output.device)`
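
For illustration, here is a minimal sketch of why the batch dimension matters. It does not use the real pipeline: `append_star` is a hypothetical helper, the random `emission` tensor stands in for the model output, and the sizes (3 utterances, 50 frames, 29 labels) are arbitrary. It assumes the star column is joined with `torch.cat` along the last dimension, which requires all other dimensions to match exactly.

```python
import torch


def append_star(emission: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: append a zero-valued star column.

    Expects emissions shaped (batch, frames, labels) and returns
    a tensor shaped (batch, frames, labels + 1).
    """
    star_dim = torch.zeros(
        (emission.size(0), emission.size(1), 1),  # batch size taken from the input
        dtype=emission.dtype,
        device=emission.device,
    )
    return torch.cat((emission, star_dim), dim=-1)


# Stand-in for the model output (assumed sizes, not from the real model).
emission = torch.randn(3, 50, 29)
batched = append_star(emission)
print(batched.shape)  # torch.Size([3, 50, 30])

# A star tensor with a hard-coded batch size of one cannot be concatenated,
# because torch.cat does not broadcast over the batch dimension.
try:
    torch.cat((emission, torch.zeros((1, 50, 1))), dim=-1)
    single_batch_ok = True
except RuntimeError:
    single_batch_ok = False
print(single_batch_ok)  # False
```

With a batch size of one, both shapes give the same result, so the change should not affect existing single-utterance usage.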