YuanxunLu / LiveSpeechPortraits

Live Speech Portraits: Real-Time Photorealistic Talking-Head Animation (SIGGRAPH Asia 2021)
MIT License
1.16k stars, 200 forks

Error while computing GMM Log Loss when predict_length = 5 #42

Closed: icedoom888 closed this issue 2 years ago

icedoom888 commented 2 years ago

When training the audio2headpose model with the default options provided in the code (predict_length = 5), the code fails when computing the GMM log loss. The error is the following: RuntimeError: The expanded size of the tensor (12) must match the existing size (60) at non-singleton dimension 3. Target sizes: [32, 100, 1, 12]. Tensor sizes: [32, 100, 1, 60]

The GMMLogLoss takes two tensors as input:

output, with size [32, 100, 25] -> [batch_size, time_frame_length, (2 * A2H_GMM_ndim + 1) * A2H_GMM_ncenter]
target, with size [32, 100, 60] -> [batch_size, time_frame_length, predict_length * 12]
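The shape bookkeeping above can be checked with a few lines of arithmetic. This is a minimal sketch, assuming A2H_GMM_ndim = 12 (pose and velocity, as described in the question), A2H_GMM_ncenter = 1, and a target carrying 12 values per predicted frame; the helper names `a2h_shapes` and `check_compatible` are hypothetical and not part of the repository.

```python
# Hypothetical helpers: they only mirror the config names quoted in the issue,
# they are not functions from the LiveSpeechPortraits code base.

def a2h_shapes(predict_length, gmm_ndim=12, gmm_ncenter=1, batch=32, frames=100):
    """Return (output_shape, target_shape) for the audio2headpose GMM loss."""
    # Per Gaussian: gmm_ndim means + gmm_ndim sigmas + 1 mixture probability.
    out_dim = (2 * gmm_ndim + 1) * gmm_ncenter
    # Assumption: the target holds gmm_ndim values for every predicted frame.
    target_dim = predict_length * gmm_ndim
    return (batch, frames, out_dim), (batch, frames, target_dim)


def check_compatible(predict_length, gmm_ndim=12):
    """The loss compares the target against a single gmm_ndim-sized mean vector,
    so the target's last dimension must equal gmm_ndim."""
    _, target_shape = a2h_shapes(predict_length, gmm_ndim)
    return target_shape[-1] == gmm_ndim


print(a2h_shapes(5))        # ((32, 100, 25), (32, 100, 60))
print(check_compatible(5))  # False: 60 != 12, the mismatch in the RuntimeError
print(check_compatible(1))  # True
```

With predict_length = 5 the target's last dimension (60) no longer matches the 12-dimensional mean, which is the same 12 vs. 60 mismatch reported in the error message.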

Can you please explain why the model outputs 25 values? I understand that you want to output 12 values (pose and velocity) and, for each of them, a mu and a sigma (24 predictions), but why do you predict an extra feature?

How do you fix this issue with predict_length=5? Using predict_length = 1 obviously fixes the issue.

YuanxunLu commented 2 years ago

'predict_length' is a deprecated option from my earlier experiments, in which I tried to predict more than one frame's results per forward pass, but I eventually found it useless. I always output the results for ONE frame in each forward pass. Sorry that I didn't carefully check the training-related code and clean up the unused parts.

GMMLogLoss is a generalized GMM loss, which means you can model the target with several Gaussian distributions. In that case, you need to output a probability (mixture weight) for each Gaussian and choose one Gaussian as your distribution at a time. In my case, I only use one Gaussian to model the head pose, so the per-Gaussian probability is meaningless. That's why you see 25 values (24 + 1). I output this probability to keep the GMM loss complete, but it is actually unused.
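To illustrate why the extra value carries no information when only one Gaussian is used, here is a minimal sketch of a diagonal-covariance GMM negative log-likelihood in plain Python. The parameter layout ([mixture logits | means | log-sigmas]) and the log-sigma parameterization are assumptions made for this example; they are not necessarily how the repository's GMMLogLoss orders or activates its outputs.

```python
import math


def gmm_nll(params, target, ndim, ncenter):
    """Negative log-likelihood of `target` under a diagonal-covariance GMM.

    Assumed (hypothetical) layout of `params`, length (2 * ndim + 1) * ncenter:
      [ncenter mixture logits | ncenter*ndim means | ncenter*ndim log-sigmas]
    """
    assert len(params) == (2 * ndim + 1) * ncenter
    assert len(target) == ndim
    logits = params[:ncenter]
    means = params[ncenter:ncenter + ncenter * ndim]
    log_sigmas = params[ncenter + ncenter * ndim:]

    # Log-softmax over the mixture logits gives the log mixture weights.
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(l - m) for l in logits))
    log_weights = [l - log_norm for l in logits]

    # Log-probability of the target under each weighted Gaussian component.
    comp_log_probs = []
    for k in range(ncenter):
        lp = log_weights[k]
        for d in range(ndim):
            mu = means[k * ndim + d]
            log_s = log_sigmas[k * ndim + d]
            z = (target[d] - mu) / math.exp(log_s)
            lp += -0.5 * z * z - log_s - 0.5 * math.log(2 * math.pi)
        comp_log_probs.append(lp)

    # Log-sum-exp over components, then negate to get the loss.
    m = max(comp_log_probs)
    return -(m + math.log(sum(math.exp(c - m) for c in comp_log_probs)))


# With ncenter = 1 the softmax of a single logit is always 1 (log weight 0),
# so the 25th value has no effect on the loss, whatever it is.
ndim = 12
target = [0.0] * ndim
p_a = [3.7] + [0.0] * ndim + [0.0] * ndim
p_b = [-1.2] + [0.0] * ndim + [0.0] * ndim
print(gmm_nll(p_a, target, ndim, 1) == gmm_nll(p_b, target, ndim, 1))  # True
```

Keeping the probability slot means the same loss works unchanged if A2H_GMM_ncenter is later raised above 1, which matches the reason given above for outputting it.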

Hope the above helps.

icedoom888 commented 2 years ago

Oh I see! It was quite hard to keep track of all the config parameters throughout the code. Makes sense now, thanks for the help! And nice work 👍