bowang-lab / scGPT

https://scgpt.readthedocs.io/en/latest/
MIT License

Cell Type Annotation scGPT_Human model state dict not correct #235

Open ManuelSokolov opened 3 months ago

ManuelSokolov commented 3 months ago

Hi, thank you for your amazing work developing scGPT.

I am following the TutorialAnnotation.ipynb tutorial using the ms dataset from https://drive.google.com/drive/folders/1Qd42YNabzyr2pWt9xoY4cVMTAxsNBt4v and the model from https://drive.google.com/drive/folders/1oWh-ZRdhtoGQ2Fw24HP41FgLoomVo-y

For this section of the tutorial (model loading):

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ntokens = len(vocab)  # size of vocabulary
model = TransformerModel(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=3,
    n_cls=num_types if CLS else 1,
    vocab=vocab,
    dropout=dropout,
    pad_token=pad_token,
    pad_value=pad_value,
    do_mvc=MVC,
    do_dab=DAB,
    use_batch_labels=INPUT_BATCH_LABELS,
    num_batch_labels=num_batch_types,
    domain_spec_batchnorm=config.DSBN,
    input_emb_style=input_emb_style,
    n_input_bins=n_input_bins,
    cell_emb_style=cell_emb_style,
    mvc_decoder_style=mvc_decoder_style,
    ecs_threshold=ecs_threshold,
    explicit_zero_prob=explicit_zero_prob,
    use_fast_transformer=fast_transformer,
    fast_transformer_backend=fast_transformer_backend,
    pre_norm=config.pre_norm,
)
if config.load_model is not None:
    try:
        model.load_state_dict(torch.load(model_file))
        logger.info(f"Loading all model params from {model_file}")
    except:
        # only load params that are in the model and match the size
        model_dict = model.state_dict()
        pretrained_dict = torch.load(model_file)
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        for k, v in pretrained_dict.items():
            logger.info(f"Loading params {k} with shape {v.shape}")
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

pre_freeze_param_count = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters() if p.requires_grad).values())

# Freeze all pre-decoder weights
for name, para in model.named_parameters():
    print("-"*20)
    print(f"name: {name}")
    if config.freeze and "encoder" in name and "transformer_encoder" not in name:
    # if config.freeze and "encoder" in name:
        print(f"freezing weights for: {name}")
        para.requires_grad = False

post_freeze_param_count = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters() if p.requires_grad).values())

logger.info(f"Total Pre freeze Params {(pre_freeze_param_count )}")
logger.info(f"Total Post freeze Params {(post_freeze_param_count )}")
wandb.log(
        {
            "info/pre_freeze_param_count": pre_freeze_param_count,
            "info/post_freeze_param_count": post_freeze_param_count,
        },
)

model.to(device)
wandb.watch(model)

if ADV:
    discriminator = AdversarialDiscriminator(
        d_model=embsize,
        n_cls=num_batch_types,
    ).to(device)

I get the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[27], line 35
     34 try:
---> 35     model.load_state_dict(torch.load(model_file, map_location=torch.device('mps')))
     36     logger.info(f"Loading all model params from {model_file}")

File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:2189, in Module.load_state_dict(self, state_dict, strict, assign)
   2188 if len(error_msgs) > 0:
-> 2189     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2190                        self.__class__.__name__, "\n\t".join(error_msgs)))
   2191 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for TransformerModel:
    Missing key(s) in state_dict: "transformer_encoder.layers.0.self_attn.in_proj_weight", "transformer_encoder.layers.0.self_attn.in_proj_bias", "transformer_encoder.layers.1.self_attn.in_proj_weight", "transformer_encoder.layers.1.self_attn.in_proj_bias", "transformer_encoder.layers.2.self_attn.in_proj_weight", "transformer_encoder.layers.2.self_attn.in_proj_bias", "transformer_encoder.layers.3.self_attn.in_proj_weight", "transformer_encoder.layers.3.self_attn.in_proj_bias", "transformer_encoder.layers.4.self_attn.in_proj_weight", "transformer_encoder.layers.4.self_attn.in_proj_bias", "transformer_encoder.layers.5.self_attn.in_proj_weight", "transformer_encoder.layers.5.self_attn.in_proj_bias", "transformer_encoder.layers.6.self_attn.in_proj_weight", "transformer_encoder.layers.6.self_attn.in_proj_bias", "transformer_encoder.layers.7.self_attn.in_proj_weight", "transformer_encoder.layers.7.self_attn.in_proj_bias", "transformer_encoder.layers.8.self_attn.in_proj_weight", "transformer_encoder.layers.8.self_attn.in_proj_bias", "transformer_encoder.layers.9.self_attn.in_proj_weight", "transformer_encoder.layers.9.self_attn.in_proj_bias", "transformer_encoder.layers.10.self_attn.in_proj_weight", "transformer_encoder.layers.10.self_attn.in_proj_bias", "transformer_encoder.layers.11.self_attn.in_proj_weight", "transformer_encoder.layers.11.self_attn.in_proj_bias", "cls_decoder._decoder.0.weight", "cls_decoder._decoder.0.bias", "cls_decoder._decoder.2.weight", "cls_decoder._decoder.2.bias", "cls_decoder._decoder.3.weight", "cls_decoder._decoder.3.bias", "cls_decoder._decoder.5.weight", "cls_decoder._decoder.5.bias", "cls_decoder.out_layer.weight", "cls_decoder.out_layer.bias". 
    Unexpected key(s) in state_dict: "flag_encoder.weight", "mvc_decoder.gene2query.weight", "mvc_decoder.gene2query.bias", "mvc_decoder.W.weight", "transformer_encoder.layers.0.self_attn.Wqkv.weight", "transformer_encoder.layers.0.self_attn.Wqkv.bias", "transformer_encoder.layers.1.self_attn.Wqkv.weight", "transformer_encoder.layers.1.self_attn.Wqkv.bias", "transformer_encoder.layers.2.self_attn.Wqkv.weight", "transformer_encoder.layers.2.self_attn.Wqkv.bias", "transformer_encoder.layers.3.self_attn.Wqkv.weight", "transformer_encoder.layers.3.self_attn.Wqkv.bias", "transformer_encoder.layers.4.self_attn.Wqkv.weight", "transformer_encoder.layers.4.self_attn.Wqkv.bias", "transformer_encoder.layers.5.self_attn.Wqkv.weight", "transformer_encoder.layers.5.self_attn.Wqkv.bias", "transformer_encoder.layers.6.self_attn.Wqkv.weight", "transformer_encoder.layers.6.self_attn.Wqkv.bias", "transformer_encoder.layers.7.self_attn.Wqkv.weight", "transformer_encoder.layers.7.self_attn.Wqkv.bias", "transformer_encoder.layers.8.self_attn.Wqkv.weight", "transformer_encoder.layers.8.self_attn.Wqkv.bias", "transformer_encoder.layers.9.self_attn.Wqkv.weight", "transformer_encoder.layers.9.self_attn.Wqkv.bias", "transformer_encoder.layers.10.self_attn.Wqkv.weight", "transformer_encoder.layers.10.self_attn.Wqkv.bias", "transformer_encoder.layers.11.self_attn.Wqkv.weight", "transformer_encoder.layers.11.self_attn.Wqkv.bias".

I did not change any code from the tutorial; I only placed the dataset and model files in the expected folders.

To help narrow this down: for the transformer model, the number of state dict entries differs between the loaded checkpoint and the constructed architecture. (Yes, I am loading with map_location="mps" because the Mac M3 has no CUDA GPU; that is not related to this issue.)

[Screenshot 2024-08-01 12:19: state dict entry counts for the loaded checkpoint vs. the constructed model]

To help you further, here is a comparison of the keys:

Keys in the current model state_dict but not in the loaded state_dict:
{'cls_decoder._decoder.3.bias', 'transformer_encoder.layers.8.self_attn.in_proj_weight', 'transformer_encoder.layers.8.self_attn.in_proj_bias', 'transformer_encoder.layers.2.self_attn.in_proj_bias', 'transformer_encoder.layers.10.self_attn.in_proj_bias', 'transformer_encoder.layers.11.self_attn.in_proj_bias', 'transformer_encoder.layers.7.self_attn.in_proj_weight', 'transformer_encoder.layers.1.self_attn.in_proj_weight', 'transformer_encoder.layers.3.self_attn.in_proj_weight', 'transformer_encoder.layers.11.self_attn.in_proj_weight', 'cls_decoder._decoder.2.bias', 'transformer_encoder.layers.7.self_attn.in_proj_bias', 'cls_decoder._decoder.3.weight', 'transformer_encoder.layers.9.self_attn.in_proj_bias', 'cls_decoder._decoder.5.bias', 'cls_decoder.out_layer.bias', 'transformer_encoder.layers.9.self_attn.in_proj_weight', 'cls_decoder.out_layer.weight', 'transformer_encoder.layers.5.self_attn.in_proj_weight', 'transformer_encoder.layers.4.self_attn.in_proj_bias', 'transformer_encoder.layers.6.self_attn.in_proj_weight', 'transformer_encoder.layers.4.self_attn.in_proj_weight', 'cls_decoder._decoder.0.weight', 'cls_decoder._decoder.2.weight', 'transformer_encoder.layers.3.self_attn.in_proj_bias', 'cls_decoder._decoder.5.weight', 'transformer_encoder.layers.5.self_attn.in_proj_bias', 'transformer_encoder.layers.0.self_attn.in_proj_bias', 'transformer_encoder.layers.0.self_attn.in_proj_weight', 'transformer_encoder.layers.1.self_attn.in_proj_bias', 'cls_decoder._decoder.0.bias', 'transformer_encoder.layers.6.self_attn.in_proj_bias', 'transformer_encoder.layers.10.self_attn.in_proj_weight', 'transformer_encoder.layers.2.self_attn.in_proj_weight'}

Keys in the loaded state_dict but not in the current model state_dict:
{'transformer_encoder.layers.9.self_attn.Wqkv.weight', 'mvc_decoder.gene2query.bias', 'transformer_encoder.layers.3.self_attn.Wqkv.weight', 'transformer_encoder.layers.1.self_attn.Wqkv.weight', 'transformer_encoder.layers.3.self_attn.Wqkv.bias', 'transformer_encoder.layers.6.self_attn.Wqkv.bias', 'transformer_encoder.layers.4.self_attn.Wqkv.bias', 'transformer_encoder.layers.9.self_attn.Wqkv.bias', 'mvc_decoder.gene2query.weight', 'flag_encoder.weight', 'transformer_encoder.layers.1.self_attn.Wqkv.bias', 'transformer_encoder.layers.8.self_attn.Wqkv.bias', 'transformer_encoder.layers.2.self_attn.Wqkv.bias', 'transformer_encoder.layers.11.self_attn.Wqkv.weight', 'transformer_encoder.layers.10.self_attn.Wqkv.bias', 'transformer_encoder.layers.6.self_attn.Wqkv.weight', 'transformer_encoder.layers.7.self_attn.Wqkv.bias', 'transformer_encoder.layers.7.self_attn.Wqkv.weight', 'transformer_encoder.layers.2.self_attn.Wqkv.weight', 'transformer_encoder.layers.5.self_attn.Wqkv.bias', 'transformer_encoder.layers.8.self_attn.Wqkv.weight', 'transformer_encoder.layers.4.self_attn.Wqkv.weight', 'transformer_encoder.layers.5.self_attn.Wqkv.weight', 'transformer_encoder.layers.0.self_attn.Wqkv.weight', 'transformer_encoder.layers.11.self_attn.Wqkv.bias', 'transformer_encoder.layers.10.self_attn.Wqkv.weight', 'transformer_encoder.layers.0.self_attn.Wqkv.bias', 'mvc_decoder.W.weight'}
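
For reference, these two lists came from a direct key comparison, roughly along these lines (a minimal sketch; model and model_file are the objects from the tutorial cell above):

import torch

# Compare parameter names between the freshly constructed model and the downloaded checkpoint.
model_dict = model.state_dict()
pretrained_dict = torch.load(model_file, map_location=torch.device("mps"))

print("Keys in the current model state_dict but not in the loaded state_dict:")
print(set(model_dict) - set(pretrained_dict))

print("Keys in the loaded state_dict but not in the current model state_dict:")
print(set(pretrained_dict) - set(model_dict))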

In sum, the model architecture built by the tutorial does not match the checkpoint downloaded from Google Drive.
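
A possible (untested) workaround occurred to me while comparing the keys: the Wqkv.* entries in the checkpoint look like flash-attn's packed QKV projection, which should correspond to in_proj_weight / in_proj_bias in PyTorch's standard nn.MultiheadAttention. If that assumption holds, the attention weights could be renamed before loading, with the pretraining-only keys (flag_encoder, mvc_decoder.*) simply dropped:

# Untested sketch: remap flash-attn style attention keys to the standard
# nn.MultiheadAttention names before loading. Assumes Wqkv packs Q, K, V in the
# same order and with the same shape as in_proj_weight / in_proj_bias.
pretrained_dict = torch.load(model_file, map_location=torch.device("mps"))
remapped = {}
for k, v in pretrained_dict.items():
    k = k.replace("self_attn.Wqkv.weight", "self_attn.in_proj_weight")
    k = k.replace("self_attn.Wqkv.bias", "self_attn.in_proj_bias")
    remapped[k] = v

model_dict = model.state_dict()
# Keep only keys that exist in the current model with matching shapes; anything else
# (flag_encoder, mvc_decoder, the untrained cls_decoder) stays at its random initialization.
remapped = {k: v for k, v in remapped.items() if k in model_dict and v.shape == model_dict[k].shape}
model_dict.update(remapped)
model.load_state_dict(model_dict)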

WhenMelancholy commented 3 months ago

I've encountered a similar issue. Could someone please explain how I should resolve this problem?

ManuelSokolov commented 3 months ago

Hi, just updating on this issue. I solved it by running on a Linux machine on AWS.
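
For anyone else who hits this: my guess is that the AWS Linux/CUDA machine has flash-attn installed, so the tutorial builds the model with the fast-transformer attention (the Wqkv.* layers) that matches the checkpoint, while on macOS/MPS scGPT falls back to the standard PyTorch attention (the in_proj_* layers) and the strict load fails. A quick way to check which path you are on (a sketch, assuming the fallback is triggered by the flash-attn import failing):

# Sketch: see whether flash-attn is importable in the current environment.
# If it is not, use_fast_transformer is effectively disabled and the Wqkv.* keys
# in the checkpoint will not match the constructed model.
try:
    import flash_attn  # noqa: F401
    print("flash-attn available: fast transformer layers (Wqkv) will be used")
except ImportError:
    print("flash-attn not available: falling back to torch.nn.MultiheadAttention (in_proj_*)")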