xxlong0 / Wonder3D

Single Image to 3D using Cross-Domain Diffusion for 3D Generation
https://www.xxlong.site/Wonder3D/
GNU Affero General Public License v3.0
4.49k stars · 351 forks

Stage 2 training BUG: following keys are missing #175

Open RenieWell opened 4 weeks ago

RenieWell commented 4 weeks ago

Thanks for sharing this work with us!

I have trained the stage 1 model and got the expected performance, using the UNet weights from the pretrained model pretrained_model_name_or_path: 'lambdalabs/sd-image-variations-diffusers'

But after training, the model trained in stage 1 can't be loaded by the stage 2 code, and I get the following error:

load pre-trained unet from ./outputs/wonder3D-mix-vanila/checkpoint/
Traceback (most recent call last):
  File "/data/.conda/envs/marigold/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/data/.conda/envs/marigold/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/data/Wonder3D/train_mvdiffusion_joint.py", line 773, in <module>
    main(cfg)
  File "/data/Wonder3D/train_mvdiffusion_joint.py", line 251, in main
    unet = UNetMV2DConditionModel.from_pretrained(cfg.pretrained_unet_path, subfolder="unet", revision=cfg.revision, **cfg.unet_from_pretrained_kwargs)
  File "/data/.conda/envs/marigold/lib/python3.10/site-packages/diffusers/models/modeling_utils.py", line 604, in from_pretrained
    raise ValueError(
ValueError: Cannot load <class 'mvdiffusion.models.unet_mv2d_condition.UNetMV2DConditionModel'> from ./outputs/wonder3D-mix-vanila/checkpoint/ because the following keys are missing:
up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.2.attentions.2.transformer_blocks.0.attn_joint_mid.to_v.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.2.attentions.2.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, mid_block.attentions.0.transformer_blocks.0.norm_joint_mid.bias, down_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.2.attentions.0.transformer_blocks.0.norm_joint_mid.bias, down_blocks.0.attentions.1.transformer_blocks.0.norm_joint_mid.weight, down_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.0.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.2.attentions.0.transformer_blocks.0.norm_joint_mid.weight, down_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.1.attentions.0.transformer_blocks.0.norm_joint_mid.bias, down_blocks.2.attentions.0.transformer_blocks.0.norm_joint_mid.weight, up_blocks.2.attentions.2.transformer_blocks.0.attn_joint_mid.to_q.weight, down_blocks.0.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.3.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.2.transformer_blocks.0.attn_joint_mid.to_q.weight, 
up_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.2.attentions.0.transformer_blocks.0.norm_joint_mid.bias, up_blocks.3.attentions.0.transformer_blocks.0.norm_joint_mid.weight, up_blocks.3.attentions.2.transformer_blocks.0.attn_joint_mid.to_q.weight, down_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.1.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.2.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.3.attentions.1.transformer_blocks.0.norm_joint_mid.weight, down_blocks.1.attentions.1.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.2.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.3.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.3.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.3.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, down_blocks.0.attentions.1.transformer_blocks.0.norm_joint_mid.bias, down_blocks.0.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.3.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.0.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, 
down_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.0.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.0.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.2.attentions.1.transformer_blocks.0.norm_joint_mid.weight, up_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.0.attentions.0.transformer_blocks.0.norm_joint_mid.weight, down_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.1.attentions.1.transformer_blocks.0.norm_joint_mid.weight, up_blocks.1.attentions.0.transformer_blocks.0.norm_joint_mid.weight, down_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.2.attentions.2.transformer_blocks.0.norm_joint_mid.weight, up_blocks.3.attentions.1.transformer_blocks.0.norm_joint_mid.bias, down_blocks.1.attentions.0.transformer_blocks.0.norm_joint_mid.bias, down_blocks.0.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.2.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.2.attentions.1.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.2.transformer_blocks.0.norm_joint_mid.bias, up_blocks.3.attentions.2.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.0.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.1.attentions.0.transformer_blocks.0.norm_joint_mid.weight, up_blocks.3.attentions.2.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, 
up_blocks.3.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, down_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.2.attentions.1.transformer_blocks.0.norm_joint_mid.bias, up_blocks.3.attentions.2.transformer_blocks.0.attn_joint_mid.to_v.weight, down_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.3.attentions.0.transformer_blocks.0.norm_joint_mid.bias, down_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.3.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.0.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.0.attentions.0.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.weight, mid_block.attentions.0.transformer_blocks.0.norm_joint_mid.weight, up_blocks.1.attentions.2.transformer_blocks.0.norm_joint_mid.weight, up_blocks.1.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.2.attentions.1.transformer_blocks.0.norm_joint_mid.weight, up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, down_blocks.1.attentions.1.transformer_blocks.0.norm_joint_mid.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, 
up_blocks.3.attentions.2.transformer_blocks.0.norm_joint_mid.weight, up_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, down_blocks.0.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.1.attentions.2.transformer_blocks.0.attn_joint_mid.to_v.weight.
Please make sure to pass low_cpu_mem_usage=False and device_map=None if you want to randomly initialize those weights or else make sure your checkpoint file is correct.

As I understand it, the missing weights are supposed to be trained in stage 2, so they can't be loaded before stage 2's training starts. Do you have any ideas about this?
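The error message itself names the workaround: diffusers only randomly initializes missing keys when low_cpu_mem_usage=False and device_map=None are passed. Below is a minimal sketch of that non-strict loading behaviour using plain torch modules (the module names here are illustrative stand-ins, not Wonder3D's actual classes):

```python
import torch
import torch.nn as nn

# Stage-1 model: has only the base attention projection.
stage1 = nn.ModuleDict({"attn": nn.Linear(8, 8)})

# Stage-2 model: adds a joint-attention projection ("attn_joint")
# that did not exist when the stage-1 checkpoint was saved.
stage2 = nn.ModuleDict({"attn": nn.Linear(8, 8), "attn_joint": nn.Linear(8, 8)})

ckpt = stage1.state_dict()  # stands in for the stage-1 checkpoint file

# strict=True (the default) raises on missing keys -- this is the
# failure mode reported above:
# stage2.load_state_dict(ckpt)  # RuntimeError: Missing key(s) ...

# strict=False loads what it can and leaves the new layers at their
# random init, which is what low_cpu_mem_usage=False / device_map=None
# allow diffusers' from_pretrained to do.
missing, unexpected = stage2.load_state_dict(ckpt, strict=False)
print(sorted(missing))  # ['attn_joint.bias', 'attn_joint.weight']
print(unexpected)       # []

# The shared weights really did transfer.
assert torch.equal(stage2["attn"].weight, stage1["attn"].weight)
```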

SunzeY commented 2 weeks ago

similar problem!

bbbbubble commented 5 days ago

same problem. @RenieWell @SunzeY @xxlong0 Have you solved this?

bbbbubble commented 1 day ago

If I add low_cpu_mem_usage=False to the from_pretrained() call, or use the from_pretrained_2d() function instead, training starts running successfully, but it won't converge...
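A plausible reason the run starts but fails to converge is that the attn_joint_mid layers begin training with fully random weights, so the new joint-attention branch injects noise into the loaded stage-1 features from step 0. One common mitigation (an assumption on my side, not confirmed as this repo's strategy) is to zero-initialize the output projection of the new branch, so the branch is a no-op at initialization and is learned gradually:

```python
import torch
import torch.nn as nn

class JointAttentionStub(nn.Module):
    """Minimal stand-in for a residual attention branch (illustrative only)."""
    def __init__(self, dim: int = 8):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scaled dot-product attention, added back as a residual branch.
        scores = self.to_q(x) @ self.to_k(x).transpose(-1, -2) / x.shape[-1] ** 0.5
        attn = torch.softmax(scores, dim=-1)
        return x + self.to_out(attn @ self.to_v(x))

block = JointAttentionStub()

# Zero the output projection: the residual branch now contributes
# nothing at step 0, so the loaded stage-1 behaviour is preserved
# and the branch's influence grows only as training updates to_out.
nn.init.zeros_(block.to_out.weight)
nn.init.zeros_(block.to_out.bias)

x = torch.randn(2, 4, 8)
assert torch.allclose(block(x), x)  # exact identity at initialization
```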

(image attachment)