FoundationVision / VAR

[NeurIPS 2024 Oral][GPT beats diffusion🔥] [scaling laws in visual generation📈] Official impl. of "Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction". An *ultra-simple, user-friendly yet state-of-the-art* codebase for autoregressive image generation!
MIT License

Inference after training on own dataset #58

Closed moeinheidari7829 closed 6 months ago

moeinheidari7829 commented 6 months ago

Dear authors, thank you for open-sourcing your great work.

I tried to train a model on my own dataset with depth = 16, using the following command:

d16, 256x256

torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
  --depth=16 --bs=768 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1

However, once the model is trained, I get key mismatch errors when loading the checkpoint for inference, as follows:

/project/st-ilker-1/moein/moein-envs/var-env/lib/python3.8/site-packages/torch/nn/init.py:46: UserWarning: Specified kernel cache directory is not writable! This disables kernel caching. Specified directory is /home/moeinh78/.cache/torch/kernels. This warning will appear only once per process. (Triggered internally at ../aten/src/ATen/native/cuda/jit_utils.cpp:1460.)
  tensor.erfinv_()
Traceback (most recent call last):
  File "sample.py", line 35, in <module>
    var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)

raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(

RuntimeError: Error(s) in loading state_dict for VAR: Missing key(s) in state_dict: "pos_start", "pos_1LC", "lvl_1L", "attn_bias_for_masking", "word_embed.weight", "word_embed.bias", "class_emb.weight", "lvl_embed.weight", "blocks.0.attn.scale_mul_1H11", "blocks.0.attn.q_bias", "blocks.0.attn.v_bias", "blocks.0.attn.zero_k_bias", "blocks.0.attn.mat_qkv.weight", "blocks.0.attn.proj.weight", "blocks.0.attn.proj.bias", "blocks.0.ffn.fc1.weight", "blocks.0.ffn.fc1.bias", "blocks.0.ffn.fc2.weight", "blocks.0.ffn.fc2.bias", "blocks.0.ada_lin.1.weight", "blocks.0.ada_lin.1.bias", "blocks.1.attn.scale_mul_1H11", "blocks.1.attn.q_bias", "blocks.1.attn.v_bias", "blocks.1.attn.zero_k_bias", "blocks.1.attn.mat_qkv.weight", "blocks.1.attn.proj.weight", "blocks.1.attn.proj.bias", "blocks.1.ffn.fc1.weight", "blocks.1.ffn.fc1.bias", "blocks.1.ffn.fc2.weight", "blocks.1.ffn.fc2.bias", "blocks.1.ada_lin.1.weight", "blocks.1.ada_lin.1.bias", "blocks.2.attn.scale_mul_1H11", "blocks.2.attn.q_bias", "blocks.2.attn.v_bias", "blocks.2.attn.zero_k_bias", "blocks.2.attn.mat_qkv.weight", "blocks.2.attn.proj.weight", "blocks.2.attn.proj.bias", "blocks.2.ffn.fc1.weight", "blocks.2.ffn.fc1.bias", "blocks.2.ffn.fc2.weight", "blocks.2.ffn.fc2.bias", "blocks.2.ada_lin.1.weight", "blocks.2.ada_lin.1.bias", "blocks.3.attn.scale_mul_1H11", "blocks.3.attn.q_bias", "blocks.3.attn.v_bias", "blocks.3.attn.zero_k_bias", "blocks.3.attn.mat_qkv.weight", "blocks.3.attn.proj.weight", "blocks.3.attn.proj.bias", "blocks.3.ffn.fc1.weight", "blocks.3.ffn.fc1.bias", "blocks.3.ffn.fc2.weight", "blocks.3.ffn.fc2.bias", "blocks.3.ada_lin.1.weight", "blocks.3.ada_lin.1.bias", "blocks.4.attn.scale_mul_1H11", "blocks.4.attn.q_bias", "blocks.4.attn.v_bias", "blocks.4.attn.zero_k_bias", "blocks.4.attn.mat_qkv.weight", "blocks.4.attn.proj.weight", "blocks.4.attn.proj.bias", "blocks.4.ffn.fc1.weight", "blocks.4.ffn.fc1.bias", "blocks.4.ffn.fc2.weight", "blocks.4.ffn.fc2.bias", "blocks.4.ada_lin.1.weight", "blocks.4.ada_lin.1.bias", "blocks.5.attn.scale_mul_1H11", "blocks.5.attn.q_bias", "blocks.5.attn.v_bias", "blocks.5.attn.zero_k_bias", "blocks.5.attn.mat_qkv.weight", "blocks.5.attn.proj.weight", "blocks.5.attn.proj.bias", "blocks.5.ffn.fc1.weight", "blocks.5.ffn.fc1.bias", "blocks.5.ffn.fc2.weight", "blocks.5.ffn.fc2.bias", "blocks.5.ada_lin.1.weight", "blocks.5.ada_lin.1.bias", "blocks.6.attn.scale_mul_1H11", "blocks.6.attn.q_bias", "blocks.6.attn.v_bias", "blocks.6.attn.zero_k_bias", "blocks.6.attn.mat_qkv.weight", "blocks.6.attn.proj.weight", "blocks.6.attn.proj.bias", "blocks.6.ffn.fc1.weight", "blocks.6.ffn.fc1.bias", "blocks.6.ffn.fc2.weight", "blocks.6.ffn.fc2.bias", "blocks.6.ada_lin.1.weight", "blocks.6.ada_lin.1.bias", "blocks.7.attn.scale_mul_1H11", "blocks.7.attn.q_bias", "blocks.7.attn.v_bias", "blocks.7.attn.zero_k_bias", "blocks.7.attn.mat_qkv.weight", "blocks.7.attn.proj.weight", "blocks.7.attn.proj.bias", "blocks.7.ffn.fc1.weight", "blocks.7.ffn.fc1.bias", "blocks.7.ffn.fc2.weight", "blocks.7.ffn.fc2.bias", "blocks.7.ada_lin.1.weight", "blocks.7.ada_lin.1.bias", "blocks.8.attn.scale_mul_1H11", "blocks.8.attn.q_bias", "blocks.8.attn.v_bias", "blocks.8.attn.zero_k_bias", "blocks.8.attn.mat_qkv.weight", "blocks.8.attn.proj.weight", "blocks.8.attn.proj.bias", "blocks.8.ffn.fc1.weight", "blocks.8.ffn.fc1.bias", "blocks.8.ffn.fc2.weight", "blocks.8.ffn.fc2.bias", "blocks.8.ada_lin.1.weight", "blocks.8.ada_lin.1.bias", "blocks.9.attn.scale_mul_1H11", "blocks.9.attn.q_bias", "blocks.9.attn.v_bias", "blocks.9.attn.zero_k_bias", 
"blocks.9.attn.mat_qkv.weight", "blocks.9.attn.proj.weight", "blocks.9.attn.proj.bias", "blocks.9.ffn.fc1.weight", "blocks.9.ffn.fc1.bias", "blocks.9.ffn.fc2.weight", "blocks.9.ffn.fc2.bias", "blocks.9.ada_lin.1.weight", "blocks.9.ada_lin.1.bias", "blocks.10.attn.scale_mul_1H11", "blocks.10.attn.q_bias", "blocks.10.attn.v_bias", "blocks.10.attn.zero_k_bias", "blocks.10.attn.mat_qkv.weight", "blocks.10.attn.proj.weight", "blocks.10.attn.proj.bias", "blocks.10.ffn.fc1.weight", "blocks.10.ffn.fc1.bias", "blocks.10.ffn.fc2.weight", "blocks.10.ffn.fc2.bias", "blocks.10.ada_lin.1.weight", "blocks.10.ada_lin.1.bias", "blocks.11.attn.scale_mul_1H11", "blocks.11.attn.q_bias", "blocks.11.attn.v_bias", "blocks.11.attn.zero_k_bias", "blocks.11.attn.mat_qkv.weight", "blocks.11.attn.proj.weight", "blocks.11.attn.proj.bias", "blocks.11.ffn.fc1.weight", "blocks.11.ffn.fc1.bias", "blocks.11.ffn.fc2.weight", "blocks.11.ffn.fc2.bias", "blocks.11.ada_lin.1.weight", "blocks.11.ada_lin.1.bias", "blocks.12.attn.scale_mul_1H11", "blocks.12.attn.q_bias", "blocks.12.attn.v_bias", "blocks.12.attn.zero_k_bias", "blocks.12.attn.mat_qkv.weight", "blocks.12.attn.proj.weight", "blocks.12.attn.proj.bias", "blocks.12.ffn.fc1.weight", "blocks.12.ffn.fc1.bias", "blocks.12.ffn.fc2.weight", "blocks.12.ffn.fc2.bias", "blocks.12.ada_lin.1.weight", "blocks.12.ada_lin.1.bias", "blocks.13.attn.scale_mul_1H11", "blocks.13.attn.q_bias", "blocks.13.attn.v_bias", "blocks.13.attn.zero_k_bias", "blocks.13.attn.mat_qkv.weight", "blocks.13.attn.proj.weight", "blocks.13.attn.proj.bias", "blocks.13.ffn.fc1.weight", "blocks.13.ffn.fc1.bias", "blocks.13.ffn.fc2.weight", "blocks.13.ffn.fc2.bias", "blocks.13.ada_lin.1.weight", "blocks.13.ada_lin.1.bias", "blocks.14.attn.scale_mul_1H11", "blocks.14.attn.q_bias", "blocks.14.attn.v_bias", "blocks.14.attn.zero_k_bias", "blocks.14.attn.mat_qkv.weight", "blocks.14.attn.proj.weight", "blocks.14.attn.proj.bias", "blocks.14.ffn.fc1.weight", "blocks.14.ffn.fc1.bias", "blocks.14.ffn.fc2.weight", "blocks.14.ffn.fc2.bias", "blocks.14.ada_lin.1.weight", "blocks.14.ada_lin.1.bias", "blocks.15.attn.scale_mul_1H11", "blocks.15.attn.q_bias", "blocks.15.attn.v_bias", "blocks.15.attn.zero_k_bias", "blocks.15.attn.mat_qkv.weight", "blocks.15.attn.proj.weight", "blocks.15.attn.proj.bias", "blocks.15.ffn.fc1.weight", "blocks.15.ffn.fc1.bias", "blocks.15.ffn.fc2.weight", "blocks.15.ffn.fc2.bias", "blocks.15.ada_lin.1.weight", "blocks.15.ada_lin.1.bias", "head_nm.ada_lin.1.weight", "head_nm.ada_lin.1.bias", "head.weight", "head.bias". Unexpected key(s) in state_dict: "epoch", "iter", "trainer", "args".

Can you please guide me on this?

mothanaprime commented 6 months ago

Maybe you can try torch.load(var_ckpt, map_location='cpu')['trainer']['var_wo_ddp']. It worked for me.
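For reference, a minimal sketch of how that fix slots into the loading step in sample.py (assuming `var` and `var_ckpt` are already defined as in the original script; the key names are taken from the error message and the suggestion above):

```python
import torch

# The training script saves a checkpoint dict with top-level keys
# 'epoch', 'iter', 'trainer', 'args' (the "unexpected keys" in the error);
# the VAR weights themselves are nested under trainer -> var_wo_ddp.
ckpt = torch.load(var_ckpt, map_location='cpu')
var.load_state_dict(ckpt['trainer']['var_wo_ddp'], strict=True)
```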

moeinheidari7829 commented 6 months ago

@mothanaprime

Thank you very much for your help. It worked for me as well.

Best.