Open huangruizhe opened 2 months ago
The current implementation assumes a batch size of one when attaching the `star` dimension: https://github.com/pytorch/audio/blob/ea437b31ce316ea3d66fe73768c0dcb94edb79ad/src/torchaudio/pipelines/_wav2vec2/utils.py#L41
However, the underlying Wav2vec model supports batch sizes greater than one, so this line should instead be:
star_dim = torch.zeros((output.size(0), output.size(1), 1), dtype=output.dtype, device=output.device)
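To illustrate the proposed change, here is a minimal sketch. It assumes the emission tensor has shape `(batch, frames, classes)` and that the star column is attached with `torch.cat` along the last dimension. The helper name `append_star_dim` and the example sizes are hypothetical and not taken from torchaudio.

```python
import torch


def append_star_dim(output: torch.Tensor) -> torch.Tensor:
    """Append a zero-valued star column to emissions of shape (batch, frames, classes).

    Hypothetical helper for illustration only; not part of torchaudio.
    """
    # Match the batch size of `output` instead of hard-coding 1, so that
    # the concatenation also works for batched input.
    star_dim = torch.zeros(
        (output.size(0), output.size(1), 1),
        dtype=output.dtype,
        device=output.device,
    )
    return torch.cat((output, star_dim), dim=-1)


# Example with an assumed batch of 4 utterances, 50 frames, 29 classes.
emissions = torch.randn(4, 50, 29)
extended = append_star_dim(emissions)
```

`torch.cat` requires all dimensions other than the concatenation dimension to match, so a star tensor with a leading dimension of 1 can only be concatenated with a batch of one.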