bowang-lab / scGPT

https://scgpt.readthedocs.io/en/latest/
MIT License
920 stars, 167 forks

State dict loading error with the recommended pretrained human model #173

Open. Yonggie opened this issue 3 months ago

Yonggie commented 3 months ago
# in the perturbation tutorial
....
model_dict = model.state_dict()
pretrained_dict = torch.load(model_file)
pretrained_dict = {
    k: v
    for k, v in pretrained_dict.items()
    if any([k.startswith(prefix) for prefix in load_param_prefixs])
}
for k, v in pretrained_dict.items():
    logger.info(f"Loading params {k} with shape {v.shape}")
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)  # exception here
....
Exception has occurred: RuntimeError
Error(s) in loading state_dict for TransformerGenerator:
    Unexpected key(s) in state_dict: "transformer_encoder.layers.0.self_attn.Wqkv.weight", "transformer_encoder.layers.0.self_attn.Wqkv.bias", "transformer_encoder.layers.1.self_attn.Wqkv.weight", "transformer_encoder.layers.1.self_attn.Wqkv.bias", "transformer_encoder.layers.2.self_attn.Wqkv.weight", "transformer_encoder.layers.2.self_attn.Wqkv.bias", "transformer_encoder.layers.3.self_attn.Wqkv.weight", "transformer_encoder.layers.3.self_attn.Wqkv.bias", "transformer_encoder.layers.4.self_attn.Wqkv.weight", "transformer_encoder.layers.4.self_attn.Wqkv.bias", "transformer_encoder.layers.5.self_attn.Wqkv.weight", "transformer_encoder.layers.5.self_attn.Wqkv.bias", "transformer_encoder.layers.6.self_attn.Wqkv.weight", "transformer_encoder.layers.6.self_attn.Wqkv.bias", "transformer_encoder.layers.7.self_attn.Wqkv.weight", "transformer_encoder.layers.7.self_attn.Wqkv.bias", "transformer_encoder.layers.8.self_attn.Wqkv.weight", "transformer_encoder.layers.8.self_attn.Wqkv.bias", "transformer_encoder.layers.9.self_attn.Wqkv.weight", "transformer_encoder.layers.9.self_attn.Wqkv.bias", "transformer_encoder.layers.10.self_attn.Wqkv.weight", "transformer_encoder.layers.10.self_attn.Wqkv.bias", "transformer_encoder.layers.11.self_attn.Wqkv.weight", "transformer_encoder.layers.11.self_attn.Wqkv.bias". 
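Every unexpected key ends in self_attn.Wqkv.weight or self_attn.Wqkv.bias. This is the packed query/key/value projection name used by flash-attn's attention module, while torch.nn.MultiheadAttention stores the same projection as in_proj_weight / in_proj_bias. The prefix filter keeps those checkpoint keys, model_dict.update adds them as extra entries, and the strict load_state_dict call rejects them. Passing strict=False would silence the error, but the attention projections would then keep their random initialization. A minimal sketch of an alternative is to rename the keys before loading. It assumes the checkpoint was saved from a flash-attn model whose packed weight has the same shape (3 * d_model, d_model) and the same Q, K, V stacking order as in_proj_weight. The name remap_flash_attn_keys is hypothetical.

```python
from typing import Dict, Mapping, TypeVar

V = TypeVar("V")

# Assumption: flash-attn checkpoints name the packed Q/K/V projection "Wqkv",
# while torch.nn.MultiheadAttention calls the same parameters
# "in_proj_weight" / "in_proj_bias". out_proj.* uses the same name in both.
_RENAMES = {
    "self_attn.Wqkv.weight": "self_attn.in_proj_weight",
    "self_attn.Wqkv.bias": "self_attn.in_proj_bias",
}


def remap_flash_attn_keys(state_dict: Mapping[str, V]) -> Dict[str, V]:
    """Return a copy of state_dict with flash-attn Wqkv keys renamed.

    Values (tensors) are passed through untouched; only key names change.
    """
    remapped: Dict[str, V] = {}
    for key, value in state_dict.items():
        new_key = key
        for old_suffix, new_suffix in _RENAMES.items():
            if key.endswith(old_suffix):
                # Keep the layer prefix, swap only the parameter name.
                new_key = key[: -len(old_suffix)] + new_suffix
                break
        remapped[new_key] = value
    return remapped


if __name__ == "__main__":
    # Placeholder strings stand in for tensors in this demo.
    ckpt = {
        "encoder.embedding.weight": "emb",
        "transformer_encoder.layers.0.self_attn.Wqkv.weight": "w0",
        "transformer_encoder.layers.0.self_attn.Wqkv.bias": "b0",
    }
    for k in remap_flash_attn_keys(ckpt):
        print(k)
```

In the tutorial snippet, the rename would be applied immediately after loading and before the prefix filter, e.g. pretrained_dict = remap_flash_attn_keys(torch.load(model_file)). Installing flash-attn, as suggested later in the thread, avoids the rename entirely by building the model with matching parameter names.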
Nik212 commented 3 months ago

Same issue here. Is there a solution?

anaistrate commented 2 months ago

FYI, I ran into the same issue and fixed it by installing flash-attn==1.0.2.
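Installing flash-attn presumably works because the model is then built with flash-attn's attention layers, whose parameter names match the checkpoint. If the package is installed but its compiled extension fails to load, scGPT appears to fall back (with a warning) to the standard PyTorch attention layers, and the error returns. It can therefore help to confirm that the import actually succeeds. The sketch below assumes that flash-attn 1.x exposes FlashMHA in flash_attn.flash_attention. The helper name can_import is hypothetical.

```python
import importlib


def can_import(module_name: str, attr: str) -> bool:
    """Return True if `module_name` imports cleanly and exposes `attr`.

    ImportError (including ModuleNotFoundError) and OSError are caught,
    since a compiled extension built against a different CUDA/PyTorch
    version can fail to load even when the package is installed.
    """
    try:
        module = importlib.import_module(module_name)
    except (ImportError, OSError):
        return False
    return hasattr(module, attr)


if __name__ == "__main__":
    # Assumption: flash-attn 1.x places FlashMHA in flash_attn.flash_attention.
    ok = can_import("flash_attn.flash_attention", "FlashMHA")
    print("FlashMHA available:", ok)
```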

Nik212 commented 2 months ago

> FYI I was running into the same issue and I fixed it by installing flash-attn==1.0.2

Would you mind sharing which versions of the PyTorch-related libraries you have? Did you just run pip install scgpt "flash-attn==1.0.2"?
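The installed versions asked about here can be collected with the standard library's importlib.metadata, without importing the packages themselves. The package list is our own choice. torchtext is included on the assumption that scGPT depends on it, which is worth checking against your install. installed_versions is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Distribution names to report (our choice; adjust as needed).
PACKAGES = ("torch", "flash-attn", "scgpt", "torchtext")


def installed_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```

Since flash-attn is compiled against a specific CUDA toolkit, torch.version.cuda, which reports the CUDA version PyTorch was built with, is also useful to include.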

Qotov commented 2 months ago

This solved my problem with the error https://github.com/bowang-lab/scGPT/issues/153#issuecomment-1950228683