harlanhong / ICCV2023-MCNET

The official code of our ICCV2023 work: Implicit Identity Representation Conditioned Memory Compensation Network for Talking Head Video Generation

512p training error #9

Open: movingright opened this issue 10 months ago

movingright commented 10 months ago

@harlanhong When I try to extend training to 512p, I get the error below after changing the image size in the config file (as you recommended here: https://github.com/harlanhong/ICCV2023-MCNET/issues/4#issuecomment-1664873417). What other changes do I need to make for 512p training?
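To answer "what other changes" systematically, one option is to list every config entry that depends on the image size before starting a 512p run. The sketch below assumes FOMM-style key names (`frame_shape`, `scale_factor`) and a hypothetical nested layout; neither is confirmed for the MCNet config, so adjust `RESOLUTION_KEYS` and the example dict to match the actual YAML file.

```python
# Hypothetical key names borrowed from FOMM-style configs; verify them
# against the YAML file actually used for training.
RESOLUTION_KEYS = {"frame_shape", "scale_factor"}


def find_resolution_keys(config, prefix=""):
    """Yield (dotted_path, value) for every entry that likely depends on image size."""
    for key, value in config.items():
        path = f"{prefix}.{key}" if prefix else key
        if key in RESOLUTION_KEYS:
            yield path, value
        if isinstance(value, dict):
            # Recurse into nested sections such as model_params.
            yield from find_resolution_keys(value, path)


# Hypothetical example config (structure assumed, not taken from MCNet).
config = {
    "dataset_params": {"frame_shape": [256, 256]},
    "model_params": {
        "kp_detector_params": {"scale_factor": 0.25},
        "generator_params": {"dense_motion_params": {"scale_factor": 0.25}},
    },
}

for path, value in find_resolution_keys(config):
    print(path, value)
```

In practice `config` would be the dict returned by `yaml.safe_load` on the training config. This only finds config entries; any resolution assumption hard-coded in the model (for example, layers whose names are built from a size) will not show up here.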

Traceback (most recent call last):
  File "/home/ubuntu/code/ICCV2023-MCNET/run.py", line 256, in <module>
    train.train(config, generator, discriminator, kp_detector, opt.checkpoint, log_dir, dataset, opt.local_rank,device,opt,writer)
  File "/home/ubuntu/code/ICCV2023-MCNET/train.py", line 119, in train
    losses_generator, generated = generator_full(x,weight,epoch=epoch) 
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/lib/python3.11/site-packages/torch-2.0.1-py3.11-linux-x86_64.egg/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/code/ICCV2023-MCNET/modules/model.py", line 319, in forward
    generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving, source_depth = depth_source, driving_depth = depth_driving,driving_image=x['driving'])
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/lib/python3.11/site-packages/torch-2.0.1-py3.11-linux-x86_64.egg/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/lib/python3.11/site-packages/torch-2.0.1-py3.11-linux-x86_64.egg/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/lib/python3.11/site-packages/torch-2.0.1-py3.11-linux-x86_64.egg/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/lib/python3.11/site-packages/torch-2.0.1-py3.11-linux-x86_64.egg/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/code/ICCV2023-MCNET/modules/generator.py", line 497, in forward
    out = self.mbUnit(out,output_dict,keypoints = kp_source['value'])
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/lib/python3.11/site-packages/torch-2.0.1-py3.11-linux-x86_64.egg/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/code/ICCV2023-MCNET/modules/generator.py", line 309, in forward
    feat = eval('self.feat_forward_proj_{}'.format(w))(out_cs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/lib/python3.11/site-packages/torch-2.0.1-py3.11-linux-x86_64.egg/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/lib/python3.11/site-packages/torch-2.0.1-py3.11-linux-x86_64.egg/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda3/lib/python3.11/site-packages/torch-2.0.1-py3.11-linux-x86_64.egg/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Given groups=1, weight of size [512, 128, 1, 1], expected input[8, 256, 64, 64] to have 128 channels, but got 256 channels instead
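The error message pins down the mismatch: a 1x1 convolution built for 128 input channels receives a 64x64 feature map with 256 channels. The `format(w)` lookup at `generator.py` line 309 suggests the projection layers are selected by a per-level key `w`. If `w` is the feature map's spatial size (an assumption, not checked against the MCNet source), a table of layers built for 256p would hand the wrong layer to 512p features. The toy sketch below models that hypothesis in plain Python; `encoder_shapes`, `build_projection_table` and `check_shapes` are hypothetical names, and the channel and width numbers are chosen only so that they reproduce the shapes in the error, not copied from MCNet's real encoder.

```python
def encoder_shapes(image_size, num_levels=3, base_channels=32):
    """Toy encoder (hypothetical): each level halves the width and doubles the channels.

    Returns a list of (channels, width) pairs, one per level.
    """
    shapes = []
    for level in range(1, num_levels + 1):
        factor = 2 ** level
        shapes.append((base_channels * factor, image_size // factor))
    return shapes


def build_projection_table(train_size):
    # Mimics layers named like feat_forward_proj_{w}: one projection per width,
    # each expecting the channel count that width had at train_size.
    return {width: channels for channels, width in encoder_shapes(train_size)}


def check_shapes(table, image_size):
    """List every level whose features would not fit the layer chosen by width."""
    problems = []
    for channels, width in encoder_shapes(image_size):
        expected = table.get(width)
        if expected is None:
            problems.append(f"no projection for width {width}")
        elif expected != channels:
            problems.append(f"width {width}: layer expects {expected} channels, got {channels}")
    return problems


print(check_shapes(build_projection_table(256), 256))  # layers and inputs agree
print(check_shapes(build_projection_table(256), 512))  # 256p layers, 512p frames
print(check_shapes(build_projection_table(512), 512))  # layers rebuilt for 512p
```

Under this hypothesis, changing `frame_shape` alone cannot help: any layers keyed by size would also have to be rebuilt (or re-keyed) for 512p, which is why the last check reports no problems.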

harlanhong commented 5 months ago

You may need to pay attention to the keypoint detector part.
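One way to read this hint, assuming the keypoint detector follows the FOMM-style convention of downscaling its input by a `scale_factor` config entry (0.25 at 256p), is to keep the detector's working resolution fixed when the frames grow. The arithmetic is sketched below; `matched_scale_factor` is a hypothetical helper, and this adjustment is not claimed to resolve the generator error in the traceback above.

```python
def matched_scale_factor(image_size, reference_size=256, reference_scale=0.25):
    """Scale factor that keeps the detector's internal input size unchanged.

    Assumes (hypothetically) a detector that resizes its input by scale_factor,
    tuned at reference_size with reference_scale.
    """
    internal_size = reference_size * reference_scale  # 64 px at the reference setting
    return internal_size / image_size


print(matched_scale_factor(256))  # 0.25
print(matched_scale_factor(512))  # 0.125
```

Keeping the detector at 64x64 preserves its receptive field relative to the face; leaving `scale_factor` at 0.25 would feed it 128x128 inputs instead, which likely changes what the detector sees even if training runs without errors.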

RobinROAR commented 4 weeks ago

> You may need to pay attention to the keypoint detector part.

Same error here :(
Just changing frame_shape to [512, 512] and using the 512 training dataset produces the same error. Could you please give some clearer advice? :)
Thanks