psyai-net / EmoTalk_release

This is the official source code for our ICCV 2023 paper "EmoTalk: Speech-Driven Emotional Disentanglement for 3D Face Animation".

RuntimeError: Mask shape should match input. mask: [4, 104, 104] input: [1, 4, 104, 104] #4

Open · zhuyetuo opened this issue 1 year ago

zhuyetuo commented 1 year ago


(zyt-nerf) amax@amax:~/zyt/audio2face/EmoTalk_release$ python demo.py --wav_path "./audio/disgust.wav"
Some weights of Wav2Vec2Model were not initialized from the model checkpoint at jonatasgrosman/wav2vec2-large-xlsr-53-english and are newly initialized: ['wav2vec2.lm_head.bias', 'wav2vec2.lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of Wav2Vec2ForSpeechClassification were not initialized from the model checkpoint at r-f/wav2vec-english-speech-emotion-recognition and are newly initialized: ['wav2vec2.lm_head.bias', 'wav2vec2.lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
/home/amax/miniconda3/envs/zyt-nerf/lib/python3.10/site-packages/torch/nn/modules/activation.py:1144: UserWarning: Converting mask without torch.bool dtype to bool; this will negatively affect performance. Prefer to use a boolean mask directly. (Triggered internally at ../aten/src/ATen/native/transformers/attention.cpp:150.)
  return torch._native_multi_head_attention(
Traceback (most recent call last):
  File "/home/amax/zyt/audio2face/EmoTalk_release/demo.py", line 111, in <module>
    main()
  File "/home/amax/zyt/audio2face/EmoTalk_release/demo.py", line 106, in main
    test(args)
  File "/home/amax/miniconda3/envs/zyt-nerf/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/amax/zyt/audio2face/EmoTalk_release/demo.py", line 30, in test
    prediction = model.predict(audio, level, person)
  File "/home/amax/zyt/audio2face/EmoTalk_release/model.py", line 140, in predict
    bs_out11 = self.transformer_decoder(hidden_states11, hidden_states_emo11_832, tgt_mask=tgt_mask11,
  File "/home/amax/miniconda3/envs/zyt-nerf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amax/miniconda3/envs/zyt-nerf/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 360, in forward
    output = mod(output, memory, tgt_mask=tgt_mask,
  File "/home/amax/miniconda3/envs/zyt-nerf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amax/miniconda3/envs/zyt-nerf/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 698, in forward
    x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
  File "/home/amax/miniconda3/envs/zyt-nerf/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 707, in _sa_block
    x = self.self_attn(x, x, x,
  File "/home/amax/miniconda3/envs/zyt-nerf/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/amax/miniconda3/envs/zyt-nerf/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1144, in forward
    return torch._native_multi_head_attention(
RuntimeError: Mask shape should match input. mask: [4, 104, 104] input: [1, 4, 104, 104]
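For readers trying to interpret the two shapes in the error: the `torch.nn.MultiheadAttention` documentation allows `attn_mask` to be either 2D `(L, S)`, shared across the batch and all heads, or 3D `(N * num_heads, L, S)`, one mask per (batch item, head) pair. The native kernel instead reports the input in a 4D `(N, num_heads, L, S)` layout. Assuming batch size N = 1 and 4 attention heads (inferred from the shapes in the message, not confirmed from EmoTalk's config), both shapes describe the same 4 × 104 × 104 block of mask values. The stdlib-only sketch below makes that relationship concrete; the helper names are hypothetical, and the causal mask is chosen purely for illustration, since the actual contents of `tgt_mask11` are not shown in this issue.

```python
from typing import List, Tuple

# Hypothetical helpers for illustration only; they mirror the attn_mask shape
# conventions documented for torch.nn.MultiheadAttention, not EmoTalk's code.


def causal_mask(seq_len: int) -> List[List[bool]]:
    """Boolean (L, L) mask; True marks a position that may NOT be attended to.

    A causal mask is an illustrative assumption, not EmoTalk's tgt_mask11.
    """
    return [[j > i for j in range(seq_len)] for i in range(seq_len)]


def shape(nested) -> Tuple[int, ...]:
    """Shape of a rectangular nested list, e.g. (4, 104, 104)."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(dims)


def is_documented_mask_shape(mask_shape, batch_size, num_heads, tgt_len, src_len) -> bool:
    """True if mask_shape is one of the documented attn_mask shapes:
    2D (L, S), shared by every batch item and head, or
    3D (N * num_heads, L, S), one mask per (batch item, head) pair.
    """
    allowed = {(tgt_len, src_len), (batch_size * num_heads, tgt_len, src_len)}
    return tuple(mask_shape) in allowed


def regroup_per_head(mask_3d, batch_size: int, num_heads: int):
    """Split the leading N * num_heads axis into (N, num_heads).

    Entry b * num_heads + h of the 3D mask becomes [b][h], i.e. the 4D
    (N, num_heads, L, S) layout reported as "input" in the error message.
    """
    if len(mask_3d) != batch_size * num_heads:
        raise ValueError("leading dimension must equal batch_size * num_heads")
    return [mask_3d[b * num_heads:(b + 1) * num_heads] for b in range(batch_size)]


if __name__ == "__main__":
    batch_size, num_heads, seq_len = 1, 4, 104  # inferred from the error message
    mask_2d = causal_mask(seq_len)
    # One copy of the same (L, L) mask per head: the documented 3D layout.
    mask_3d = [mask_2d for _ in range(batch_size * num_heads)]
    print(shape(mask_3d))                                           # (4, 104, 104)
    print(shape(regroup_per_head(mask_3d, batch_size, num_heads)))  # (1, 4, 104, 104)
    print(is_documented_mask_shape(shape(mask_3d), batch_size, num_heads, seq_len, seq_len))  # True
```

Under these assumptions, the mask in the traceback follows the documented 3D convention, and the mismatch is raised inside the native fast path (`torch._native_multi_head_attention`) when it compares that mask against the regrouped 4D layout. The sketch uses a boolean mask, as the UserWarning in the log recommends. Whether passing a single 2D `(L, S)` mask avoids this check depends on the installed PyTorch version and is not something this sketch verifies.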

ZiqiaoPeng commented 9 months ago

Has the problem been solved? The latest code should not have this problem.