yuangan / EAT_code

Official code for ICCV 2023 paper: "Efficient Emotional Adaptation for Audio-Driven Talking-Head Generation".

Dimension mismatch when training Emotional Adaptation #18

Closed phphuc612 closed 7 months ago

phphuc612 commented 8 months ago

Hi @yuangan, I ran into trouble training Emotional Adaptation with 1 GPU: runtime errors were raised due to mismatched dimensions. Thank you for your great work and for taking the time to help me out.

Environment diff from README:

Errors

1. Mismatched shapes in face_feature_map

```
Traceback (most recent call last):
  File "prompt_st_dp_eam3d.py", line 129, in <module>
    train(config, generator, discriminator, kp_detector, audio2kptransformer, emotionprompt, sidetuning, opt.checkpoint, log_dir, dataset, opt.device_ids)
  File "/home/phphuc/Desktop/EAT_code/train_transformer.py", line 272, in train_batch_deepprompt_eam3d_sidetuning
    losses_generator, generated = generator_full(x, train_params['train_with_img'])
  File "/home/phphuc/anaconda3/envs/eat/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/phphuc/anaconda3/envs/eat/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 166, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/home/phphuc/anaconda3/envs/eat/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/phphuc/Desktop/EAT_code/modules/model_transformer.py", line 781, in forward
    he_driving_emo, input_st = self.audio2kptransformer(x, kp_canonical, emoprompt=emoprompt, deepprompt=deepprompt, side=True)           # {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}
  File "/home/phphuc/anaconda3/envs/eat/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/phphuc/Desktop/EAT_code/modules/transformer.py", line 807, in forward
    face_feature_map.repeat(bs, seqlen, 1, 1, 1).reshape(bs * seqlen, 32, 64, 64)),
RuntimeError: shape '[55, 32, 64, 64]' is invalid for input of size 28835840
```
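The numbers in the error are consistent with an un-split batch: 28835840 / (32·64·64) = 220 = 4 × 55, i.e. the per-GPU batch at the `repeat`/`reshape` step was 4 instead of the expected 1. A minimal sketch of that step (`bs` and `seqlen` values read off the traceback; the surrounding logic is an assumption, not the exact EAT code):

```python
import torch

# bs and seqlen as implied by the traceback: bs * seqlen == 55
bs, seqlen = 1, 55

# Expected case: face_feature_map carries a per-GPU batch of 1
face_feature_map = torch.zeros(1, 32, 64, 64)
ok = face_feature_map.repeat(bs, seqlen, 1, 1, 1).reshape(bs * seqlen, 32, 64, 64)
print(ok.shape)  # torch.Size([55, 32, 64, 64])

# Failing case: batch_size=4 on a single GPU, so DataParallel has nothing
# to split and the leading dim is 4, making the repeated tensor 4x too large
bad = torch.zeros(4, 32, 64, 64).repeat(bs, seqlen, 1, 1, 1)
print(bad.numel())  # 28835840, which cannot be reshaped to [55, 32, 64, 64]
```

The `reshape` on `bad` raises exactly the RuntimeError shown above.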

2. Unexpected case

yuangan commented 7 months ago

Thank you for your attention. The batch_size should match the number of your GPUs: for instance, with 4 GPUs, the batch_size should be 4. Because of the impact of the frame length seqlen, training EAT with a batch_size greater than 1 on one GPU has not been implemented.
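In other words (a sketch of the assumption, not the EAT implementation): `nn.DataParallel` scatters the input along dim 0 across devices, so setting batch_size equal to the number of GPUs leaves each replica's forward with a per-GPU batch of 1, which the seqlen-dependent reshape expects. `torch.chunk` mimics that scatter on CPU:

```python
import torch

# torch.chunk splits along dim 0, like nn.DataParallel's scatter step.
batch = torch.zeros(4, 3)                  # batch_size = 4
replicas = torch.chunk(batch, 4, dim=0)    # as if scattered to 4 GPUs
print([r.shape[0] for r in replicas])      # [1, 1, 1, 1] -> supported case

# On a single GPU nothing is split: the forward sees a batch dim of 4,
# which EAT's seqlen-dependent repeat/reshape does not handle.
on_one_gpu = torch.chunk(batch, 1, dim=0)
print(on_one_gpu[0].shape[0])              # 4 -> not implemented
```

So on a single GPU the only supported configuration is batch_size = 1.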