xxlong0 / Wonder3D

Single Image to 3D using Cross-Domain Diffusion for 3D Generation
https://www.xxlong.site/Wonder3D/
GNU Affero General Public License v3.0
4.82k stars 387 forks

Stage 2 training BUG: following keys are missing #175

Open RenieWell opened 6 months ago

RenieWell commented 6 months ago

Thanks for sharing this work with us!

I trained the stage 1 model and got the expected performance, initializing the UNet weights from the pretrained model set by pretrained_model_name_or_path: 'lambdalabs/sd-image-variations-diffusers'.

But after training, the stage 1 model can't be loaded by the stage 2 code, and I get the following error:

load pre-trained unet from ./outputs/wonder3D-mix-vanila/checkpoint/
Traceback (most recent call last):
  File "/data/.conda/envs/marigold/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/data/.conda/envs/marigold/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/data/Wonder3D/train_mvdiffusion_joint.py", line 773, in <module>
    main(cfg)
  File "/data/Wonder3D/train_mvdiffusion_joint.py", line 251, in main
    unet = UNetMV2DConditionModel.from_pretrained(cfg.pretrained_unet_path, subfolder="unet", revision=cfg.revision, **cfg.unet_from_pretrained_kwargs)
  File "/data/.conda/envs/marigold/lib/python3.10/site-packages/diffusers/models/modeling_utils.py", line 604, in from_pretrained
    raise ValueError(
ValueError: Cannot load <class 'mvdiffusion.models.unet_mv2d_condition.UNetMV2DConditionModel'> from ./outputs/wonder3D-mix-vanila/checkpoint/ because the following keys are missing:
up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.2.attentions.2.transformer_blocks.0.attn_joint_mid.to_v.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.2.attentions.2.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, mid_block.attentions.0.transformer_blocks.0.norm_joint_mid.bias, down_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.2.attentions.0.transformer_blocks.0.norm_joint_mid.bias, down_blocks.0.attentions.1.transformer_blocks.0.norm_joint_mid.weight, down_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.0.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.2.attentions.0.transformer_blocks.0.norm_joint_mid.weight, down_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.1.attentions.0.transformer_blocks.0.norm_joint_mid.bias, down_blocks.2.attentions.0.transformer_blocks.0.norm_joint_mid.weight, up_blocks.2.attentions.2.transformer_blocks.0.attn_joint_mid.to_q.weight, down_blocks.0.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.3.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.2.transformer_blocks.0.attn_joint_mid.to_q.weight, 
up_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.2.attentions.0.transformer_blocks.0.norm_joint_mid.bias, up_blocks.3.attentions.0.transformer_blocks.0.norm_joint_mid.weight, up_blocks.3.attentions.2.transformer_blocks.0.attn_joint_mid.to_q.weight, down_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.1.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.2.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.3.attentions.1.transformer_blocks.0.norm_joint_mid.weight, down_blocks.1.attentions.1.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.2.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.3.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.3.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.3.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, down_blocks.0.attentions.1.transformer_blocks.0.norm_joint_mid.bias, down_blocks.0.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.3.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.0.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, 
down_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.0.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.0.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.2.attentions.1.transformer_blocks.0.norm_joint_mid.weight, up_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.0.attentions.0.transformer_blocks.0.norm_joint_mid.weight, down_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.1.attentions.1.transformer_blocks.0.norm_joint_mid.weight, up_blocks.1.attentions.0.transformer_blocks.0.norm_joint_mid.weight, down_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.2.attentions.2.transformer_blocks.0.norm_joint_mid.weight, up_blocks.3.attentions.1.transformer_blocks.0.norm_joint_mid.bias, down_blocks.1.attentions.0.transformer_blocks.0.norm_joint_mid.bias, down_blocks.0.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.2.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.2.attentions.1.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.2.transformer_blocks.0.norm_joint_mid.bias, up_blocks.3.attentions.2.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.0.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.1.attentions.0.transformer_blocks.0.norm_joint_mid.weight, up_blocks.3.attentions.2.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, 
up_blocks.3.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, down_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.2.attentions.1.transformer_blocks.0.norm_joint_mid.bias, up_blocks.3.attentions.2.transformer_blocks.0.attn_joint_mid.to_v.weight, down_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.3.attentions.0.transformer_blocks.0.norm_joint_mid.bias, down_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.3.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.0.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.0.attentions.0.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.weight, mid_block.attentions.0.transformer_blocks.0.norm_joint_mid.weight, up_blocks.1.attentions.2.transformer_blocks.0.norm_joint_mid.weight, up_blocks.1.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.2.attentions.1.transformer_blocks.0.norm_joint_mid.weight, up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, down_blocks.1.attentions.1.transformer_blocks.0.norm_joint_mid.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, 
up_blocks.3.attentions.2.transformer_blocks.0.norm_joint_mid.weight, up_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, down_blocks.0.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.1.attentions.2.transformer_blocks.0.attn_joint_mid.to_v.weight.
Please make sure to pass low_cpu_mem_usage=False and device_map=None if you want to randomly initialize those weights or else make sure your checkpoint file is correct.

As I understand it, the missing weights are supposed to be trained in stage 2, so a stage 1 checkpoint can't contain them before stage 2 training starts. Do you have any ideas about this?
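To check which layers the error is actually complaining about, the missing-key list can be grouped by submodule name. The snippet below is only a rough sketch in plain Python: the helper name `summarize_missing_keys` and the shortened `sample` message are made up for illustration, and it assumes the message has the same shape as the one quoted above.

```python
from collections import Counter


def summarize_missing_keys(message: str) -> Counter:
    """Count missing parameter names by the submodule they belong to.

    Hypothetical helper: it only parses error text shaped like the
    ValueError quoted in this thread.
    """
    # Keep the text between "keys are missing:" and the trailing hint.
    _, _, tail = message.partition("keys are missing:")
    keys_text, _, _ = tail.partition(". Please make sure")
    keys = [k.strip() for k in keys_text.split(",") if k.strip()]

    counts = Counter()
    for key in keys:
        parts = key.split(".")
        name = "other"
        if "transformer_blocks" in parts:
            idx = parts.index("transformer_blocks")
            # Layout: ...transformer_blocks.<index>.<submodule>.<param>
            if idx + 2 < len(parts):
                name = parts[idx + 2]
        counts[name] += 1
    return counts


# Shortened sample in the same shape as the error above (illustration only).
sample = (
    "ValueError: Cannot load <class '...'> because the following keys are missing: "
    "mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, "
    "mid_block.attentions.0.transformer_blocks.0.norm_joint_mid.bias, "
    "up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight. "
    "Please make sure to pass low_cpu_mem_usage=False and device_map=None"
)

print(dict(summarize_missing_keys(sample)))
```

If every key lands under attn_joint_mid or norm_joint_mid, as appears to be the case in the list above, only the joint-attention layers are absent from the checkpoint and the rest of the UNet weights were found.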

SunzeY commented 5 months ago

similar problem!

bbbbubble commented 5 months ago

same problem. @RenieWell @SunzeY @xxlong0 Have you solved this?

bbbbubble commented 5 months ago

If I add low_cpu_mem_usage=False to the from_pretrained() call, or use from_pretrained_2d() instead, training starts successfully, but it is hard to get it to converge...
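For anyone wondering what "randomly initialize those weights" in the error hint amounts to, here is a conceptual sketch of a partial load. It is not the diffusers implementation; `merge_checkpoint` and the toy state dicts are hypothetical, and plain strings stand in for real tensors.

```python
def merge_checkpoint(model_state: dict, checkpoint_state: dict):
    """Overlay checkpoint values on a freshly initialized model state.

    Returns (merged, missing, unexpected). Entries absent from the
    checkpoint keep whatever value the model was initialized with.
    """
    merged = dict(model_state)
    missing = []
    for name in model_state:
        if name in checkpoint_state:
            merged[name] = checkpoint_state[name]
        else:
            # Nothing to copy: the initial (random) value stays in place.
            missing.append(name)
    unexpected = [n for n in checkpoint_state if n not in model_state]
    return merged, missing, unexpected


# Toy example: the stage 2 model has a joint-attention layer that the
# stage 1 checkpoint never contained.
stage2_model = {
    "conv_in.weight": "random-init",
    "attn_joint_mid.to_q.weight": "random-init",
}
stage1_checkpoint = {"conv_in.weight": "trained"}

merged, missing, unexpected = merge_checkpoint(stage2_model, stage1_checkpoint)
print(merged)
print(missing)
```

In this picture the joint-attention layers start stage 2 untrained while everything else is restored from stage 1. Whether that alone explains the slow convergence reported here is not established in this thread.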

yyuezhi commented 3 months ago

Hi, I am running into the same problem. Have you solved this? @bbbbubble @SunzeY @RenieWell

mengxuyiGit commented 3 months ago

Maybe try replacing "from_pretrained" with "from_pretrained_2d" at https://github.com/xxlong0/Wonder3D/blob/deeba9833570fce09dd4da393f6318475e85a735/train_mvdiffusion_joint.py#L251?

liuyifan22 commented 1 month ago

@bbbbubble @mengxuyiGit @RenieWell @SunzeY

Hello, if the problem persists, you can try my solution: change the cd_attention_mid attribute in "./configs/train/stage1-mix-6views-lvis.yaml" from false to true,

and change line 237 in ./train_mvdiffusion_image.py to

unet = UNetMV2DConditionModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="unet", revision=cfg.revision, **cfg.unet_from_pretrained_kwargs)

It works for me. It seems the problem is in the first stage rather than the second. Please let me know if anything comes up.
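A quick way to spot this kind of mismatch is to diff the UNet flags used by the two stages. The sketch below assumes cd_attention_mid is one of the entries passed through unet_from_pretrained_kwargs; the `diff_flags` helper and the example values (including num_views) are hypothetical, so substitute the values from your own YAML files.

```python
def diff_flags(stage1: dict, stage2: dict) -> dict:
    """Return {flag: (stage1_value, stage2_value)} for flags that differ."""
    names = sorted(set(stage1) | set(stage2))
    return {
        name: (stage1.get(name), stage2.get(name))
        for name in names
        if stage1.get(name) != stage2.get(name)
    }


# Hypothetical values for illustration; read the real ones from your configs.
stage1_kwargs = {"cd_attention_mid": False, "num_views": 6}
stage2_kwargs = {"cd_attention_mid": True, "num_views": 6}

print(diff_flags(stage1_kwargs, stage2_kwargs))
```

Any flag that shows up in the diff and adds layers to the model is a candidate for parameters that exist in the stage 2 UNet but not in the stage 1 checkpoint.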

bbbbubble commented 1 month ago

That doesn't seem to work. The "missing keys" error then shows up in the first stage:

[rank0]: ValueError: Cannot load <class 'mvdiffusion.models.unet_mv2d_condition.UNetMV2DConditionModel'> from /home/azhe.cp/avatar_utils/sd-image-variations-diffusers because the following keys are missing:
[rank0]: mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, mid_block.attentions.0.transformer_blocks.0.norm_joint_mid.bias, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, mid_block.attentions.0.transformer_blocks.0.norm_joint_mid.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias.
[rank0]: Please make sure to pass low_cpu_mem_usage=False and device_map=None if you want to randomly initialize those weights or else make sure your checkpoint file is correct.

liuyifan22 commented 1 month ago

Sorry, I was using Wonder3D's checkpoint as the starting point for my stage 1 training, and my change fits that case. If you want to train from lambdalabs/sd-image-variations-diffusers, my method will not work. Sorry for the confusion.