Tencent / HunyuanDiT

Hunyuan-DiT : A Powerful Multi-Resolution Diffusion Transformer with Fine-Grained Chinese Understanding
https://dit.hunyuan.tencent.com/

How to use the full-parameter training checkpoint for sample_t2i.py? #107

Open · a5704607 opened this issue 3 months ago

a5704607 commented 3 months ago

I ran full-parameter training with:

PYTHONPATH=./ sh hydit/train.sh --index-file dataset/porcelain/jsons/porcelain.json

The checkpoint is saved under 005-dit_g2_full_1024p/checkpoints. How do I use this checkpoint with sample_t2i.py?

MichaelFan01 commented 3 months ago

Same question here.

a5704607 commented 3 months ago

Same question here.

I think you can use 005-dit_g2_full_1024p/checkpoints/0010000.pt/mp_rank_00_model_states.pt
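
In case it helps to check first, here is a minimal sketch (not from the repo, just plain PyTorch) for inspecting what that file actually contains; the path is the one suggested above:

```python
import torch

# Minimal sketch: peek inside the DeepSpeed-saved checkpoint before trying to
# load it into the model. The path follows the one mentioned in this thread.
ckpt_path = "005-dit_g2_full_1024p/checkpoints/0010000.pt/mp_rank_00_model_states.pt"
checkpoint = torch.load(ckpt_path, map_location="cpu")

# The model weights are nested under one of the top-level keys, so list them.
print(list(checkpoint.keys()))
```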

MichaelFan01 commented 3 months ago

Same question here.

I think you can use 005-dit_g2_full_1024p/checkpoints/0010000.pt/mp_rank_00_model_states.pt

I tried it, but it doesn't quite work:

File "/maindata/data/shared/multimodal/michael.fan/aigc-apps/HunyuanDiT-main/hydit/inference.py", line 223, in __init__
    self.model.load_state_dict(state_dict)
File "/maindata/data/shared/multimodal/michael.fan/env/miniconda3/envs/hunyuanDit/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for HunYuanDiT:
    Missing key(s) in state_dict: "text_embedding_padding", "mlp_t5.0.weight", "mlp_t5.0.bias", "mlp_t5.2.weight", "mlp_t5.2.bias", "pooler.positional_embedding", "pooler.k_proj.weight", "pooler.k_proj.bias", "pooler.q_proj.weight", "pooler.q_proj.bias", "pooler.v_proj.weight", "pooler.v_proj.bias", "pooler.c_proj.weight", "pooler.c_proj.bias", "style_embedder.weight", "x_embedder.proj.weight", "x_embedder.proj.bias", "t_embedder.mlp.0.weight", "t_embedder.mlp.0.bias", "t_embedder.mlp.2.weight", "t_embedder.mlp.2.bias", "extra_embedder.0.weight", "extra_embedder.0.bias", "extra_embedder.2.weight", "extra_embedder.2.bias", "blocks.0.norm1.weight", "blocks.0.norm1.bias", "blocks.0.attn1.Wqkv.weight", "blocks.0.attn1.Wqkv.bias", "blocks.0.attn1.q_norm.weight", "blocks.0.attn1.q_norm.bias", "blocks.0.attn1.k_norm.weight", "blocks.0.attn1.k_norm.bias", "blocks.0.attn1.out_proj.weight", "blocks.0.attn1.out_proj.bias", "blocks.0.norm2.weight", "blocks.0.norm2.bias", "blocks.0.mlp.fc1.weight", "blocks.0.mlp.fc1.bias", "blocks.0.mlp.fc2.weight", "blocks.0.mlp.fc2.bias", "blocks.0.default_modulation.1.weight", "blocks.0.default_modulation.1.bias", "blocks.0.attn2.q_proj.weight", "blocks.0.attn2.q_proj.bias", "blocks.0.attn2.kv_proj.weight", "blocks.0.attn2.kv_proj.bias", "blocks.0.attn2.q_norm.weight", "blocks.0.attn2.q_norm.bias", "blocks.0.attn2.k_norm.weight", "blocks.0.attn2.k_norm.bias", "blocks.0.attn2.out_proj.weight", "blocks.0.attn2.out_proj.bias", "blocks.0.norm3.weight", "blocks.0.norm3.bias", "blocks.1.norm1.weight", "blocks.1.norm1.bias", "blocks.1.attn1.Wqkv.weight", "blocks.1.attn1.Wqkv.bias", "blocks.1.attn1.q_norm.weight", "blocks.1.attn1.q_norm.bias", "blocks.1.attn1.k_norm.weight", "blocks.1.attn1.k_norm.bias", "blocks.1.attn1.out_proj.weight", "blocks.1.attn1.out_proj.bias", "blocks.1.norm2.weight", "blocks.1.norm2.bias", "blocks.1.mlp.fc1.weight", "blocks.1.mlp.fc1.bias", "blocks.1.mlp.fc2.weight", "blocks.1.mlp.fc2.bias", "blocks.1.default_modulation.1.weight", "blocks.1.default_modulation.1.bias", "blocks.1.attn2.q_proj.weight", "blocks.1.attn2.q_proj.bias", "blocks.1.attn2.kv_proj.weight", "blocks.1.attn2.kv_proj.bias", "blocks.1.attn2.q_norm.weight", "blocks.1.attn2.q_norm.bias", "blocks.1.attn2.k_norm.weight", "blocks.1.attn2.k_norm.bias", "blocks.1.attn2.out_proj.weight", "blocks.1.attn2.out_proj.bias", "blocks.1.norm3.weight", "blocks.1.norm3.bias", "blocks.2.norm1.weight", "blocks.2.norm1.bias", "blocks.2.attn1.Wqkv.weight", "blocks.2.attn1.Wqkv.bias", "blocks.2.attn1.q_norm.weight", "blocks.2.attn1.q_norm.bias", "blocks.2.attn1.k_norm.weight", "blocks.2.attn1.k_norm.bias", "blocks.2.attn1.out_proj.weight", "blocks.2.attn1.out_proj.bias"

a5704607 commented 3 months ago


Regarding the missing-keys error above: key: 'ema' (the saved file is a dict; take the 'ema' entry before calling load_state_dict).

MichaelFan01 commented 3 months ago


Thanks, that fixed it.

jonathanyin12 commented 3 months ago

Running into the same issue. What do you mean by this?

key: 'ema'

Is there some way to simply point to the checkpoint path using the command line?

jonathanyin12 commented 3 months ago

Figured it out. After loading the weights from 'mp_rank_00_model_states.pt', you have to index the resulting dict with 'ema' before loading the state dict. Code ref
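
For anyone landing here later, a minimal sketch of that workaround, assuming the HunYuanDiT module has already been constructed elsewhere (the actual wiring lives in hydit/inference.py and may differ slightly; the helper name below is just for illustration):

```python
import torch
from torch import nn

def load_full_param_checkpoint(model: nn.Module, ckpt_path: str) -> nn.Module:
    """Load a full-parameter training checkpoint saved by hydit/train.sh.

    The saved file is a dict; per the comments above, the usable weights sit
    under the 'ema' key, so index into it before calling load_state_dict.
    """
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state_dict = checkpoint["ema"]  # take the 'ema' entry first
    model.load_state_dict(state_dict)
    return model

# Usage (constructing the HunYuanDiT model itself depends on the repo and is not shown):
# model = load_full_param_checkpoint(
#     model,
#     "005-dit_g2_full_1024p/checkpoints/0010000.pt/mp_rank_00_model_states.pt",
# )
```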