pytorch / audio

Data manipulation and transformation for audio signal processing, powered by PyTorch
https://pytorch.org/audio
BSD 2-Clause "Simplified" License

Cannot batch inference with WavLM in `torchaudio.pipelines` #3700

Open BakerBunker opened 10 months ago

BakerBunker commented 10 months ago

🐛 Describe the bug

Batched inference with WavLM (passing `lengths` to `extract_features`) triggers an `AssertionError` in the `WavLMSelfAttention` module.

import torch
import torchaudio

wavlm = torchaudio.pipelines.WAVLM_LARGE.get_model().cuda()
# Two padded waveforms; only the first 2000 and 3000 samples are valid.
wavlm.extract_features(torch.randn(2, 16000, device='cuda'), lengths=torch.tensor([2000, 3000], device='cuda'), num_layers=1)
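For context on the `lengths` argument in the call above: it says how many samples of each padded row are valid, and padding-aware models commonly turn such lengths into a boolean padding mask. The sketch below only illustrates that general idea. It is not torchaudio's implementation, and `padding_mask_from_lengths` is a hypothetical helper introduced here.

```python
import torch


def padding_mask_from_lengths(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """Return a (batch, max_len) boolean mask that is True at padded positions.

    Hypothetical helper for illustration only; not part of torchaudio.
    """
    # positions has shape (1, max_len); lengths.unsqueeze(1) has shape (batch, 1).
    positions = torch.arange(max_len, device=lengths.device).unsqueeze(0)
    # Broadcasting compares every position against each row's valid length.
    return positions >= lengths.unsqueeze(1)


# Two rows padded to 4 steps, with 2 and 3 valid steps respectively.
mask = padding_mask_from_lengths(torch.tensor([2, 3]), max_len=4)
```

Whenever `lengths` is given, a mask like this is non-empty for any row shorter than the padded length, which is the situation the repro creates.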

Log:

AssertionError                            Traceback (most recent call last)
[<ipython-input-43-11ed28e9b1d6>](https://localhost:8080/#) in <cell line: 3>()
      1 import torchaudio
      2 wavlm=torchaudio.pipelines.WAVLM_LARGE.get_model().cuda()
----> 3 wavlm.extract_features(torch.randn(2,16000,device='cuda'),lengths=torch.tensor([2000,3000],device='cuda'),num_layers=1)

8 frames
[/usr/local/lib/python3.10/dist-packages/torchaudio/models/wav2vec2/model.py](https://localhost:8080/#) in extract_features(self, waveforms, lengths, num_layers)
     82         """
     83         x, lengths = self.feature_extractor(waveforms, lengths)
---> 84         x = self.encoder.extract_features(x, lengths, num_layers)
     85         return x, lengths
     86 

[/usr/local/lib/python3.10/dist-packages/torchaudio/models/wav2vec2/components.py](https://localhost:8080/#) in extract_features(self, features, lengths, num_layers)
    508     ) -> List[Tensor]:
    509         x, masks = self._preprocess(features, lengths)
--> 510         return self.transformer.get_intermediate_outputs(x, attention_mask=masks, num_layers=num_layers)
    511 
    512 

[/usr/local/lib/python3.10/dist-packages/torchaudio/models/wav2vec2/components.py](https://localhost:8080/#) in get_intermediate_outputs(self, x, attention_mask, num_layers)
    457         x = self._preprocess(x)
    458         for layer in self.layers:
--> 459             x, position_bias = layer(x, attention_mask, position_bias=position_bias)
    460             ret.append(x)
    461             if num_layers is not None and len(ret) >= num_layers:

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

[/usr/local/lib/python3.10/dist-packages/torchaudio/models/wav2vec2/components.py](https://localhost:8080/#) in forward(self, x, attention_mask, position_bias, key_padding_mask)
    387             x = self.layer_norm(x)
    388 
--> 389         x, position_bias = self.attention(
    390             x, attention_mask=attention_mask, position_bias=position_bias, key_padding_mask=key_padding_mask
    391         )

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

[/usr/local/lib/python3.10/dist-packages/torchaudio/models/wav2vec2/wavlm_attention.py](https://localhost:8080/#) in forward(self, query, key_padding_mask, attention_mask, position_bias)
    163         bsz, seq_len, embed_dim = query.size()
    164         assert embed_dim == self.embed_dim
--> 165         assert attention_mask is None
    166 
    167         if self.rel_attn_embed is not None and position_bias is None:

AssertionError:
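A possible workaround, sketched under the assumption that calling `extract_features` without `lengths` leaves `attention_mask` as `None` (which is what the assertion above requires): run each utterance on its own, trimmed to its valid length. `extract_features_per_utterance` is a hypothetical helper, not a torchaudio API, and it gives up the speed benefit of real batching.

```python
from typing import List, Optional

import torch


def extract_features_per_utterance(
    model,
    waveforms: torch.Tensor,
    lengths: torch.Tensor,
    num_layers: Optional[int] = None,
) -> List[List[torch.Tensor]]:
    """Call `model.extract_features` once per utterance instead of once per batch.

    Hypothetical helper: each row of `waveforms` is trimmed to its valid length
    and passed as a batch of one WITHOUT `lengths`, so no attention mask is needed.
    """
    outputs = []
    with torch.inference_mode():
        for waveform, length in zip(waveforms, lengths.tolist()):
            trimmed = waveform[:length].unsqueeze(0)  # shape (1, length)
            features, _ = model.extract_features(trimmed, num_layers=num_layers)
            outputs.append(features)  # list of per-layer tensors for this utterance
    return outputs
```

With the model from the report this would be called as `extract_features_per_utterance(wavlm, waveforms, lengths, num_layers=1)`. The result is a list with one entry per utterance, because the trimmed inputs produce different numbers of frames and cannot be stacked without padding again.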

Versions

PyTorch version: 2.1.0+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.27.7
Libc version: glibc-2.35

Python version: 3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.120+-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 525.105.17
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.6
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 2
On-line CPU(s) list: 0,1
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) CPU @ 2.00GHz
CPU family: 6
Model: 85
Thread(s) per core: 2
Core(s) per socket: 1
Socket(s): 1
Stepping: 3
BogoMIPS: 4000.42
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat md_clear arch_capabilities
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 32 KiB (1 instance)
L1i cache: 32 KiB (1 instance)
L2 cache: 1 MiB (1 instance)
L3 cache: 38.5 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0,1
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable; SMT Host state unknown
Vulnerability Meltdown: Vulnerable
Vulnerability Mmio stale data: Vulnerable
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2: Vulnerable, IBPB: disabled, STIBP: disabled, PBRSB-eIBRS: Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Vulnerable

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==2.1.0+cu118
[pip3] torchaudio==2.1.0+cu118
[pip3] torchdata==0.7.0
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.16.0
[pip3] torchvision==0.16.0+cu118
[pip3] triton==2.1.0
[conda] Could not collect