DinoMan / speech-driven-animation


Runtime error when loading state_dict for Generator #70

Closed (MrMho closed this issue 1 year ago)

MrMho commented 1 year ago

Hi,

I am running this code using Python 3.7:

# run_example.py

import sda
import scipy.io.wavfile as wav
from PIL import Image

va = sda.VideoAnimator(gpu=-1, model_path="timit") # Instantiate the animator
fs, audio_clip = wav.read("example/audio.wav")
still_frame = Image.open("example/image.bmp")
vid, aud = va(still_frame, audio_clip, fs=fs)

I get the following error when trying to create an instance of sda.VideoAnimator:

Traceback (most recent call last):
  File "n:/[...]/speech-driven-animation/run_example.py", line 5, in <module>
    va = sda.VideoAnimator(gpu=-1, model_path="timit") # Instantiate the animator
  File "n:\[...]\speech-driven-animation\sda\sda.py", line 151, in __init__
    self.generator.load_state_dict(model_dict['generator'])
  File "C:\Users\[...]\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1498, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Generator:
        Missing key(s) in state_dict: "aux_dcl.0.weight", "aux_dcl.1.weight", "aux_dcl.1.bias", "aux_dcl.1.running_mean", "aux_dcl.1.running_var".
        size mismatch for dcl.0.0.weight: copying a param with shape torch.Size([394, 1024, 4, 3]) from checkpoint, the shape in current model is torch.Size([384, 1024, 4, 3]).

The same error occurs when I try to use the other models (crema, grid).

Do you know why that is?

Thanks in advance.

DinoMan commented 1 year ago

Sorry, it looks like we made some mistakes when cleaning up the code. We will update master now to fix this.
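For readers hitting a similar load_state_dict error, the missing keys and size mismatches in the traceback can be listed before loading by comparing the checkpoint against the model's own state_dict. The following is only a minimal sketch: the helper name diff_state_dicts is hypothetical and not part of sda or PyTorch, and it only assumes that each value exposes a .shape attribute, as torch tensors do.

```python
def diff_state_dicts(checkpoint_sd, model_sd):
    """Compare two state_dict-like mappings (parameter name -> tensor).

    Hypothetical helper, not part of sda or PyTorch. Returns:
      missing    - names the model expects but the checkpoint lacks
      unexpected - names in the checkpoint that the model does not define
      mismatched - {name: (checkpoint_shape, model_shape)} for shared names
    """
    missing = sorted(k for k in model_sd if k not in checkpoint_sd)
    unexpected = sorted(k for k in checkpoint_sd if k not in model_sd)
    mismatched = {
        k: (tuple(checkpoint_sd[k].shape), tuple(model_sd[k].shape))
        for k in model_sd
        if k in checkpoint_sd
        and tuple(checkpoint_sd[k].shape) != tuple(model_sd[k].shape)
    }
    return missing, unexpected, mismatched


# Possible use with PyTorch (assumes `checkpoint_path` and `generator` exist):
#   model_dict = torch.load(checkpoint_path, map_location="cpu")
#   print(diff_state_dicts(model_dict["generator"], generator.state_dict()))
```

A non-empty result means the model definition and the checkpoint were produced by different versions of the code, which matches the explanation given above.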

DinoMan commented 1 year ago

It should be fixed now. Sorry for the inconvenience. Please let us know whether it worked by closing the issue.

MrMho commented 1 year ago

Thanks a lot for the quick reply and the fix!

It is not quite working yet. The error I mentioned above disappeared but now there is another problem:

Traceback (most recent call last):
  File "n:/[...]/speech-driven-animation/run_example.py", line 9, in <module>
    vid, aud = va(still_frame, audio_clip, fs=fs)
  File "n:\[...]\speech-driven-animation\sda\sda.py", line 229, in __call__
    frame = self.preprocess_img(frame)
  File "n:\[...]\speech-driven-animation\sda\sda.py", line 191, in preprocess_img
    src = self.fa.get_landmarks(img)[0][self.stablePntsIDs, :]
  File "C:\Users\[...]\anaconda3\lib\site-packages\face_alignment\api.py", line 110, in get_landmarks
    return self.get_landmarks_from_image(image_or_path, detected_faces, return_bboxes, return_landmark_score)
  File "C:\Users\[...]\anaconda3\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "C:\Users\[...]\anaconda3\lib\site-packages\face_alignment\api.py", line 170, in get_landmarks_from_image
    pts, pts_img, scores = get_preds_fromhm(out, center.numpy(), scale)
  File "C:\Users\[...]\anaconda3\lib\site-packages\face_alignment\utils.py", line 201, in get_preds_fromhm
    preds, preds_orig = _get_preds_fromhm(hm, idx, center, scale)
SystemError: CPUDispatcher(<function _get_preds_fromhm at 0x0000024A31BBA9D8>) returned a result with an error set

By the way, in the example from my first post, I had to convert the image to a numpy array in order to get sda.utils.get_image to work.
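The conversion mentioned here can be sketched as follows. This is an illustration only: the helper name load_frame_as_array is hypothetical, and forcing the image to RGB is an assumption, not something the thread confirms sda requires.

```python
import numpy as np
from PIL import Image


def load_frame_as_array(path):
    """Open an image file and return it as a (height, width, 3) uint8 array.

    Hypothetical helper; converting to RGB is an assumption made here.
    """
    with Image.open(path) as img:
        return np.asarray(img.convert("RGB"))


# Possible use in run_example.py (requires sda, so left as a comment):
#   still_frame = load_frame_as_array("example/image.bmp")
#   vid, aud = va(still_frame, audio_clip, fs=fs)
```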

MrMho commented 1 year ago

I got it working in a fresh virtual environment. I guess I was using outdated package versions, but I cannot tell which package caused the trouble.
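To narrow down which package differs between a broken and a working environment, one option is to print the installed versions in both and compare them. The sketch below is an assumption-laden example: the package list is a guess based on the modules named in the tracebacks (CPUDispatcher is a numba class), and importlib.metadata requires Python 3.8 or later, so it would not run on the Python 3.7 interpreter mentioned in the first post without the importlib_metadata backport.

```python
from importlib import metadata  # standard library from Python 3.8 onward


def installed_versions(packages):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # The package list is a guess based on the tracebacks in this thread.
    candidates = ["torch", "face-alignment", "numba", "numpy", "scipy", "Pillow"]
    for name, version in installed_versions(candidates).items():
        print(f"{name}: {version or 'not installed'}")
```

Running this in both environments and diffing the output shows which versions changed.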