However, once the model is trained, I get mismatch errors between keys of the trained model as follows:
/project/st-ilker-1/moein/moein-envs/var-env/lib/python3.8/site-packages/torch/nn/init.py:46: UserWarning: Specified kernel cache directory is not writable! This disables kernel caching. Specified directory is /home/moeinh78/.cache/torch/kernels. This warning will appear only once per process. (Triggered internally at ../aten/src/ATen/native/cuda/jitutils.cpp:1460.)
tensor.erfinv()
Traceback (most recent call last):
File "sample.py", line 35, in
var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
Dear authors, thank you for open-sourcing your great work.
I tried to train a model on my own dataset with depth=16, using the following command:
# d16, 256x256
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
  --depth=16 --bs=768 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1
However, once the model is trained, I get mismatch errors between keys of the trained model as follows:
/project/st-ilker-1/moein/moein-envs/var-env/lib/python3.8/site-packages/torch/nn/init.py:46: UserWarning: Specified kernel cache directory is not writable! This disables kernel caching. Specified directory is /home/moeinh78/.cache/torch/kernels. This warning will appear only once per process. (Triggered internally at ../aten/src/ATen/native/cuda/jitutils.cpp:1460.)
  tensor.erfinv()
Traceback (most recent call last):
  File "sample.py", line 35, in <module>
var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)
RuntimeError: Error(s) in loading state_dict for VAR: Missing key(s) in state_dict: "pos_start", "pos_1LC", "lvl_1L", "attn_bias_for_masking", "word_embed.weight", "word_embed.bias", "class_emb.weight", "lvl_embed.weight", "blocks.0.attn.scale_mul_1H11", "blocks.0.attn.q_bias", "blocks.0.attn.v_bias", "blocks.0.attn.zero_k_bias", "blocks.0.attn.mat_qkv.weight", "blocks.0.attn.proj.weight", "blocks.0.attn.proj.bias", "blocks.0.ffn.fc1.weight", "blocks.0.ffn.fc1.bias", "blocks.0.ffn.fc2.weight", "blocks.0.ffn.fc2.bias", "blocks.0.ada_lin.1.weight", "blocks.0.ada_lin.1.bias", "blocks.1.attn.scale_mul_1H11", "blocks.1.attn.q_bias", "blocks.1.attn.v_bias", "blocks.1.attn.zero_k_bias", "blocks.1.attn.mat_qkv.weight", "blocks.1.attn.proj.weight", "blocks.1.attn.proj.bias", "blocks.1.ffn.fc1.weight", "blocks.1.ffn.fc1.bias", "blocks.1.ffn.fc2.weight", "blocks.1.ffn.fc2.bias", "blocks.1.ada_lin.1.weight", "blocks.1.ada_lin.1.bias", "blocks.2.attn.scale_mul_1H11", "blocks.2.attn.q_bias", "blocks.2.attn.v_bias", "blocks.2.attn.zero_k_bias", "blocks.2.attn.mat_qkv.weight", "blocks.2.attn.proj.weight", "blocks.2.attn.proj.bias", "blocks.2.ffn.fc1.weight", "blocks.2.ffn.fc1.bias", "blocks.2.ffn.fc2.weight", "blocks.2.ffn.fc2.bias", "blocks.2.ada_lin.1.weight", "blocks.2.ada_lin.1.bias", "blocks.3.attn.scale_mul_1H11", "blocks.3.attn.q_bias", "blocks.3.attn.v_bias", "blocks.3.attn.zero_k_bias", "blocks.3.attn.mat_qkv.weight", "blocks.3.attn.proj.weight", "blocks.3.attn.proj.bias", "blocks.3.ffn.fc1.weight", "blocks.3.ffn.fc1.bias", "blocks.3.ffn.fc2.weight", "blocks.3.ffn.fc2.bias", "blocks.3.ada_lin.1.weight", "blocks.3.ada_lin.1.bias", "blocks.4.attn.scale_mul_1H11", "blocks.4.attn.q_bias", "blocks.4.attn.v_bias", "blocks.4.attn.zero_k_bias", "blocks.4.attn.mat_qkv.weight", "blocks.4.attn.proj.weight", "blocks.4.attn.proj.bias", "blocks.4.ffn.fc1.weight", "blocks.4.ffn.fc1.bias", "blocks.4.ffn.fc2.weight", "blocks.4.ffn.fc2.bias", "blocks.4.ada_lin.1.weight", 
"blocks.4.ada_lin.1.bias", "blocks.5.attn.scale_mul_1H11", "blocks.5.attn.q_bias", "blocks.5.attn.v_bias", "blocks.5.attn.zero_k_bias", "blocks.5.attn.mat_qkv.weight", "blocks.5.attn.proj.weight", "blocks.5.attn.proj.bias", "blocks.5.ffn.fc1.weight", "blocks.5.ffn.fc1.bias", "blocks.5.ffn.fc2.weight", "blocks.5.ffn.fc2.bias", "blocks.5.ada_lin.1.weight", "blocks.5.ada_lin.1.bias", "blocks.6.attn.scale_mul_1H11", "blocks.6.attn.q_bias", "blocks.6.attn.v_bias", "blocks.6.attn.zero_k_bias", "blocks.6.attn.mat_qkv.weight", "blocks.6.attn.proj.weight", "blocks.6.attn.proj.bias", "blocks.6.ffn.fc1.weight", "blocks.6.ffn.fc1.bias", "blocks.6.ffn.fc2.weight", "blocks.6.ffn.fc2.bias", "blocks.6.ada_lin.1.weight", "blocks.6.ada_lin.1.bias", "blocks.7.attn.scale_mul_1H11", "blocks.7.attn.q_bias", "blocks.7.attn.v_bias", "blocks.7.attn.zero_k_bias", "blocks.7.attn.mat_qkv.weight", "blocks.7.attn.proj.weight", "blocks.7.attn.proj.bias", "blocks.7.ffn.fc1.weight", "blocks.7.ffn.fc1.bias", "blocks.7.ffn.fc2.weight", "blocks.7.ffn.fc2.bias", "blocks.7.ada_lin.1.weight", "blocks.7.ada_lin.1.bias", "blocks.8.attn.scale_mul_1H11", "blocks.8.attn.q_bias", "blocks.8.attn.v_bias", "blocks.8.attn.zero_k_bias", "blocks.8.attn.mat_qkv.weight", "blocks.8.attn.proj.weight", "blocks.8.attn.proj.bias", "blocks.8.ffn.fc1.weight", "blocks.8.ffn.fc1.bias", "blocks.8.ffn.fc2.weight", "blocks.8.ffn.fc2.bias", "blocks.8.ada_lin.1.weight", "blocks.8.ada_lin.1.bias", "blocks.9.attn.scale_mul_1H11", "blocks.9.attn.q_bias", "blocks.9.attn.v_bias", "blocks.9.attn.zero_k_bias", "blocks.9.attn.mat_qkv.weight", "blocks.9.attn.proj.weight", "blocks.9.attn.proj.bias", "blocks.9.ffn.fc1.weight", "blocks.9.ffn.fc1.bias", "blocks.9.ffn.fc2.weight", "blocks.9.ffn.fc2.bias", "blocks.9.ada_lin.1.weight", "blocks.9.ada_lin.1.bias", "blocks.10.attn.scale_mul_1H11", "blocks.10.attn.q_bias", "blocks.10.attn.v_bias", "blocks.10.attn.zero_k_bias", "blocks.10.attn.mat_qkv.weight", "blocks.10.attn.proj.weight", 
"blocks.10.attn.proj.bias", "blocks.10.ffn.fc1.weight", "blocks.10.ffn.fc1.bias", "blocks.10.ffn.fc2.weight", "blocks.10.ffn.fc2.bias", "blocks.10.ada_lin.1.weight", "blocks.10.ada_lin.1.bias", "blocks.11.attn.scale_mul_1H11", "blocks.11.attn.q_bias", "blocks.11.attn.v_bias", "blocks.11.attn.zero_k_bias", "blocks.11.attn.mat_qkv.weight", "blocks.11.attn.proj.weight", "blocks.11.attn.proj.bias", "blocks.11.ffn.fc1.weight", "blocks.11.ffn.fc1.bias", "blocks.11.ffn.fc2.weight", "blocks.11.ffn.fc2.bias", "blocks.11.ada_lin.1.weight", "blocks.11.ada_lin.1.bias", "blocks.12.attn.scale_mul_1H11", "blocks.12.attn.q_bias", "blocks.12.attn.v_bias", "blocks.12.attn.zero_k_bias", "blocks.12.attn.mat_qkv.weight", "blocks.12.attn.proj.weight", "blocks.12.attn.proj.bias", "blocks.12.ffn.fc1.weight", "blocks.12.ffn.fc1.bias", "blocks.12.ffn.fc2.weight", "blocks.12.ffn.fc2.bias", "blocks.12.ada_lin.1.weight", "blocks.12.ada_lin.1.bias", "blocks.13.attn.scale_mul_1H11", "blocks.13.attn.q_bias", "blocks.13.attn.v_bias", "blocks.13.attn.zero_k_bias", "blocks.13.attn.mat_qkv.weight", "blocks.13.attn.proj.weight", "blocks.13.attn.proj.bias", "blocks.13.ffn.fc1.weight", "blocks.13.ffn.fc1.bias", "blocks.13.ffn.fc2.weight", "blocks.13.ffn.fc2.bias", "blocks.13.ada_lin.1.weight", "blocks.13.ada_lin.1.bias", "blocks.14.attn.scale_mul_1H11", "blocks.14.attn.q_bias", "blocks.14.attn.v_bias", "blocks.14.attn.zero_k_bias", "blocks.14.attn.mat_qkv.weight", "blocks.14.attn.proj.weight", "blocks.14.attn.proj.bias", "blocks.14.ffn.fc1.weight", "blocks.14.ffn.fc1.bias", "blocks.14.ffn.fc2.weight", "blocks.14.ffn.fc2.bias", "blocks.14.ada_lin.1.weight", "blocks.14.ada_lin.1.bias", "blocks.15.attn.scale_mul_1H11", "blocks.15.attn.q_bias", "blocks.15.attn.v_bias", "blocks.15.attn.zero_k_bias", "blocks.15.attn.mat_qkv.weight", "blocks.15.attn.proj.weight", "blocks.15.attn.proj.bias", "blocks.15.ffn.fc1.weight", "blocks.15.ffn.fc1.bias", "blocks.15.ffn.fc2.weight", "blocks.15.ffn.fc2.bias", 
"blocks.15.ada_lin.1.weight", "blocks.15.ada_lin.1.bias", "head_nm.ada_lin.1.weight", "head_nm.ada_lin.1.bias", "head.weight", "head.bias".
	Unexpected key(s) in state_dict: "epoch", "iter", "trainer", "args".
Can you please guide me on this?