facebookresearch / audiocraft

Audiocraft is a library for audio processing and generation with deep learning. It features the state-of-the-art EnCodec audio compressor / tokenizer, along with MusicGen, a simple and controllable music generation LM with textual and melodic conditioning.
MIT License

MAGNeT: Invalid shape for attention bias #397

Closed timothelaborie closed 5 months ago

timothelaborie commented 5 months ago

I ran the example code from https://huggingface.co/facebook/magnet-medium-30secs

It crashes with this traceback:

```
wav = model.generate(descriptions)  # generates 2 samples.
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\audiocraft\models\genmodel.py", line 161, in generate
    tokens = self._generate_tokens(attributes, prompt_tokens, progress)
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\audiocraft\models\genmodel.py", line 228, in _generate_tokens
    gen_tokens = self.lm.generate(
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\audiocraft\models\lm_magnet.py", line 131, in generate
    return self._generate_magnet(prompt=prompt,
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\audiocraft\models\lm_magnet.py", line 232, in _generate_magnet
    gen_sequence, curr_step = self._generate_stage(gen_sequence,
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\audiocraft\models\lm_magnet.py", line 372, in _generate_stage
    all_logits = model(sequence, [], condition_tensors, stage=stage)
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\audiocraft\models\lm.py", line 257, in forward
    out = self.transformer(input_, cross_attention_src=cross_attention_input,
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\audiocraft\modules\transformer.py", line 705, in forward
    x = self._apply_layer(layer, x, *args, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\audiocraft\modules\transformer.py", line 662, in _apply_layer
    return layer(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\audiocraft\modules\transformer.py", line 557, in forward
    self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\torch\nn\modules\transformer.py", line 715, in _sa_block
    x = self.self_attn(x, x, x,
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\audiocraft\modules\transformer.py", line 413, in forward
    x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\xformers\ops\fmha\__init__.py", line 223, in memory_efficient_attention
    return _memory_efficient_attention(
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\xformers\ops\fmha\__init__.py", line 321, in _memory_efficient_attention
    return _memory_efficient_attention_forward(
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\xformers\ops\fmha\__init__.py", line 334, in _memory_efficient_attention_forward
    inp.validate_inputs()
  File "C:\ProgramData\Anaconda3\envs\py310\lib\site-packages\xformers\ops\fmha\common.py", line 151, in validate_inputs
    raise ValueError(
ValueError: Invalid shape for attention bias: torch.Size([1500, 1500]) (expected (4, 24, 1500, 1500))
  query.shape: torch.Size([4, 1500, 24, 64])
  key.shape  : torch.Size([4, 1500, 24, 64])
  value.shape: torch.Size([4, 1500, 24, 64])
```
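For context on what the validator is complaining about: audiocraft passes a 2-D attention bias of shape `(seq, seq)`, while this xformers build expects one already expanded to `(batch, heads, seq, seq)` to match the query/key/value layout. A minimal NumPy sketch of that shape relationship, using the dimensions from the traceback (batch 4, 24 heads, 1500 steps); this is purely illustrative of the mismatch, not a patch for the library:

```python
import numpy as np

# Dimensions taken from the error message above (illustrative only).
B, H, T = 4, 24, 1500

# The 2-D additive bias that triggers the ValueError: shape (T, T).
bias_2d = np.zeros((T, T), dtype=np.float32)

# The shape this xformers version expects: (B, H, T, T).
# broadcast_to returns a read-only view, so no extra memory is copied.
bias_4d = np.broadcast_to(bias_2d, (B, H, T, T))

print(bias_2d.shape)  # (1500, 1500)        -> rejected by validate_inputs
print(bias_4d.shape)  # (4, 24, 1500, 1500) -> the expected layout
```

In practice the mismatch comes from the audiocraft/xformers version combination rather than from user code, which is why the linked issue resolves it by aligning package versions instead of reshaping the mask by hand.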

lonzi commented 5 months ago

See: https://github.com/facebookresearch/audiocraft/issues/390