soumik-kanad / diff2lip


t, the shape in current model is torch.Size([256]). size mismatch for out.2.weight: copying a param with shape torch.Size([6, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([3, 128, 3, 3]). #2

Closed: constan1 closed this issue 11 months ago

constan1 commented 11 months ago

Parameter mismatch when loading the checkpoint. I am running:

python generate.py --video_path "test.mp4" --audio_path "InputAudio/test_audio.mp3" --model_path "checkpoint/e7.24.1.3_model260000_paper.pt" --out_path "OutputVideo/output.mp4"

All the weights seem to have a size mismatch when copied from the checkpoint.
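A size mismatch usually means the model was constructed with different hyperparameters from the ones the checkpoint was trained with. As a rough diagnostic, here is a minimal sketch that compares parameter shapes. The helper `shape_mismatches` is hypothetical and not part of diff2lip. It works on plain `{name: shape}` dicts so it runs without PyTorch; with PyTorch you would build those dicts as `{k: tuple(v.shape) for k, v in state_dict.items()}`.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def shape_mismatches(ckpt: Dict[str, Shape],
                     model: Dict[str, Shape]) -> List[Tuple[str, Shape, Shape]]:
    """Hypothetical diagnostic helper (not part of diff2lip).

    Returns (name, checkpoint_shape, model_shape) for every parameter name
    present in both dicts whose shapes differ, sorted by name.
    """
    return [
        (name, ckpt[name], model[name])
        for name in sorted(ckpt.keys() & model.keys())
        if ckpt[name] != model[name]
    ]


if __name__ == "__main__":
    # Shapes copied from the error message in this issue's title.
    ckpt = {"out.2.weight": (6, 128, 3, 3)}
    model = {"out.2.weight": (3, 128, 3, 3)}
    for name, c, m in shape_mismatches(ckpt, model):
        print(f"{name}: checkpoint {c} vs model {m}")
```

Here the checkpoint's final layer has 6 output channels while the freshly built model has 3. In guided-diffusion-style UNets, which diff2lip appears to build on (an assumption based on its flag names), `learn_sigma=True` doubles the output channels from 3 to 6. So a 3-vs-6 difference is consistent with the model being created without the checkpoint's model and diffusion flags.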

constan1 commented 11 months ago

I also get

raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for TFGModel: Missing key(s) in state_dict: "input_blocks.3.0.op.weight", "input_blocks.3.0.op.bias"...

All the weights seem to be missing, even though I downloaded the checkpoint.
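The `Missing key(s)` error is the other side of the same problem: layers that exist in the freshly built model but not in the checkpoint, or the reverse. A minimal sketch with a hypothetical helper `key_diff` that lists both directions, again on plain collections of key names. PyTorch's `Module.load_state_dict(..., strict=False)` reports the same information through the `missing_keys` and `unexpected_keys` fields of its return value, but it still raises on shape mismatches, so it is a diagnostic rather than a fix.

```python
from typing import Iterable, Set, Tuple


def key_diff(ckpt_keys: Iterable[str],
             model_keys: Iterable[str]) -> Tuple[Set[str], Set[str]]:
    """Hypothetical helper mirroring the two key lists PyTorch reports.

    missing    = keys the model expects but the checkpoint lacks
    unexpected = keys in the checkpoint that the model does not define
    """
    ckpt, model = set(ckpt_keys), set(model_keys)
    return model - ckpt, ckpt - model


if __name__ == "__main__":
    # "input_blocks.3.0.op.*" comes from the error above; the
    # checkpoint-only name is a made-up placeholder for illustration.
    model_keys = ["input_blocks.3.0.op.weight", "input_blocks.3.0.op.bias", "out.2.weight"]
    ckpt_keys = ["hypothetical_ckpt_only.weight", "out.2.weight"]
    missing, unexpected = key_diff(ckpt_keys, model_keys)
    print("missing:", sorted(missing))
    print("unexpected:", sorted(unexpected))
```

Many missing keys at once, as reported here, more likely points to the model being built with different architecture flags than to a corrupt download.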

soumik-kanad commented 11 months ago

I think you might have missed the rest of the flags. I checked again, and single-video inference also works for me.

You still need to run the same command as in scripts/inference.sh, with all the other flags unchanged:

python generate.py $MODEL_FLAGS  $DIFFUSION_FLAGS  $SAMPLE_FLAGS $DATA_FLAGS $TFG_FLAGS $GEN_FLAGS
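The shell line simply concatenates six flag groups. Below is a minimal Python sketch of the same assembly that fails loudly when a group is missing, which is what happened in the original command. The helper `build_generate_argv` is hypothetical, not part of diff2lip, and the real contents of MODEL_FLAGS, DIFFUSION_FLAGS, DATA_FLAGS and TFG_FLAGS must still be copied unchanged from scripts/inference.sh; the demo uses placeholder strings.

```python
import sys
from typing import Dict, List

# Group order as in the command quoted from scripts/inference.sh.
FLAG_GROUPS = ["MODEL_FLAGS", "DIFFUSION_FLAGS", "SAMPLE_FLAGS",
               "DATA_FLAGS", "TFG_FLAGS", "GEN_FLAGS"]


def build_generate_argv(groups: Dict[str, str]) -> List[str]:
    """Hypothetical helper (not part of diff2lip): join the six flag strings
    into one argv for generate.py, refusing to proceed if any group is
    missing or empty."""
    missing = [name for name in FLAG_GROUPS if not groups.get(name, "").strip()]
    if missing:
        raise ValueError("missing flag groups: " + ", ".join(missing))
    argv = [sys.executable, "generate.py"]
    for name in FLAG_GROUPS:
        # str.split() breaks on whitespace, roughly what the shell does
        # to an unquoted $VAR (no quote or glob handling here).
        argv.extend(groups[name].split())
    return argv


if __name__ == "__main__":
    # Placeholder strings only; copy the real ones from scripts/inference.sh.
    groups = {name: f"--placeholder_{name.lower()}=1" for name in FLAG_GROUPS}
    print(build_generate_argv(groups)[1:])
    # subprocess.run(build_generate_argv(groups), check=True) would launch it.
```

Passing only the sample and generation flags, as in the original command, makes the helper raise instead of silently letting generate.py build a model with default hyperparameters.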

The only difference is that your SAMPLE_FLAGS and GEN_FLAGS change from

SAMPLE_FLAGS="--sampling_seed=7   $sample_input_flags --timestep_respacing ddim25 --use_ddim True --model_path=$model_path --sample_path=$sample_path"
GEN_FLAGS="--generate_from_filelist 1 --test_video_dir=$test_video_dir --filelist=$filelist --save_orig=False --face_det_batch_size 64 --pads 0,0,0,0"

to

SAMPLE_FLAGS="--sampling_seed=7   $sample_input_flags --timestep_respacing ddim25 --use_ddim True --model_path=$model_path"
GEN_FLAGS="--generate_from_filelist $generate_from_filelist  --video_path=$video_path --audio_path=$audio_path --out_path=$out_path --save_orig=False --face_det_batch_size 64 --pads 0,0,0,0"

where the variables generate_from_filelist, video_path, audio_path, out_path and model_path need to be defined accordingly:

generate_from_filelist=0
video_path="path/to/video.mp4"
audio_path="path/to/audio.mp4"
out_path="path/to/output.mp4"
model_path="path/to/model.pt"

For the cross setting, also set sample_input_flags="--sampling_input_type=gt --sampling_ref_type=gt".
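To make the single-video changes concrete, here is a minimal sketch that reproduces the SAMPLE_FLAGS and GEN_FLAGS strings from this thread for given paths. The helper `single_video_flags` is hypothetical and not part of diff2lip; its default `sample_input_flags` is the cross-setting value given above. Like the shell version, it assumes paths without spaces.

```python
from typing import Dict


def single_video_flags(video_path: str, audio_path: str, out_path: str,
                       model_path: str,
                       sample_input_flags: str = "--sampling_input_type=gt --sampling_ref_type=gt",
                       ) -> Dict[str, str]:
    """Hypothetical helper: build the single-video SAMPLE_FLAGS and GEN_FLAGS.

    Compared with the filelist version, --sample_path, --test_video_dir and
    --filelist are dropped, and --video_path/--audio_path/--out_path are added.
    """
    generate_from_filelist = 0  # single-video mode
    sample_flags = (
        f"--sampling_seed=7 {sample_input_flags} --timestep_respacing ddim25 "
        f"--use_ddim True --model_path={model_path}"
    )
    gen_flags = (
        f"--generate_from_filelist {generate_from_filelist} --video_path={video_path} "
        f"--audio_path={audio_path} --out_path={out_path} --save_orig=False "
        "--face_det_batch_size 64 --pads 0,0,0,0"
    )
    return {"SAMPLE_FLAGS": sample_flags, "GEN_FLAGS": gen_flags}


if __name__ == "__main__":
    flags = single_video_flags("test.mp4", "InputAudio/test_audio.mp3",
                               "OutputVideo/output.mp4",
                               "checkpoint/e7.24.1.3_model260000_paper.pt")
    for name, value in flags.items():
        print(f'{name}="{value}"')
```

The printed lines can be pasted in place of the two assignments in scripts/inference.sh; the other four flag groups stay as they are.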