device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ntokens = len(vocab)  # size of vocabulary
model = TransformerModel(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=3,
    n_cls=num_types if CLS else 1,
    vocab=vocab,
    dropout=dropout,
    pad_token=pad_token,
    pad_value=pad_value,
    do_mvc=MVC,
    do_dab=DAB,
    use_batch_labels=INPUT_BATCH_LABELS,
    num_batch_labels=num_batch_types,
    domain_spec_batchnorm=config.DSBN,
    input_emb_style=input_emb_style,
    n_input_bins=n_input_bins,
    cell_emb_style=cell_emb_style,
    mvc_decoder_style=mvc_decoder_style,
    ecs_threshold=ecs_threshold,
    explicit_zero_prob=explicit_zero_prob,
    use_fast_transformer=fast_transformer,
    fast_transformer_backend=fast_transformer_backend,
    pre_norm=config.pre_norm,
)
if config.load_model is not None:
    try:
        model.load_state_dict(torch.load(model_file))
        logger.info(f"Loading all model params from {model_file}")
    except Exception:
        # fall back to loading only the params that exist in the model
        # under the same name and with a matching shape
        model_dict = model.state_dict()
        pretrained_dict = torch.load(model_file)
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and v.shape == model_dict[k].shape
        }
        for k, v in pretrained_dict.items():
            logger.info(f"Loading params {k} with shape {v.shape}")
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
pre_freeze_param_count = sum(
    dict((p.data_ptr(), p.numel()) for p in model.parameters() if p.requires_grad).values()
)
# Freeze all pre-decoder weights
for name, para in model.named_parameters():
    print("-" * 20)
    print(f"name: {name}")
    if config.freeze and "encoder" in name and "transformer_encoder" not in name:
        # if config.freeze and "encoder" in name:
        print(f"freezing weights for: {name}")
        para.requires_grad = False
post_freeze_param_count = sum(
    dict((p.data_ptr(), p.numel()) for p in model.parameters() if p.requires_grad).values()
)
logger.info(f"Total Pre freeze Params {pre_freeze_param_count}")
logger.info(f"Total Post freeze Params {post_freeze_param_count}")
wandb.log(
    {
        "info/pre_freeze_param_count": pre_freeze_param_count,
        "info/post_freeze_param_count": post_freeze_param_count,
    },
)
model.to(device)
wandb.watch(model)
if ADV:
    discriminator = AdversarialDiscriminator(
        d_model=embsize,
        n_cls=num_batch_types,
    ).to(device)
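A side note on the parameter counts above: the `dict((p.data_ptr(), p.numel()) ...)` construction deduplicates by storage pointer so that tied or shared parameters (e.g. a shared embedding matrix) are counted once. A minimal illustration of the same idea, using `id()` as a stand-in for `data_ptr()` so it runs without torch:

```python
class FakeParam:
    """Stand-in for a tensor parameter; numel() returns its element count."""
    def __init__(self, n):
        self.n = n
    def numel(self):
        return self.n

shared = FakeParam(10)
params = [shared, shared, FakeParam(5)]  # first two are the *same* object

# Naive sum double-counts the shared parameter.
naive = sum(p.numel() for p in params)                       # 25
# Keying by identity (analogous to data_ptr()) counts it once.
dedup = sum({id(p): p.numel() for p in params}.values())     # 15
```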
I did not change any code from the tutorial; I only placed the datasets and the checkpoint in the expected folders.
To help narrow this down: for the transformer model, the state_dict keys differ between the loaded checkpoint and the instantiated architecture. (Yes, I am loading on "mps" because the Mac M3 has no CUDA GPU, but that is unrelated to the mismatch.)
For reference, here are the differing keys:
Keys in the current model state_dict but not in the loaded state_dict:
{'cls_decoder._decoder.3.bias', 'transformer_encoder.layers.8.self_attn.in_proj_weight', 'transformer_encoder.layers.8.self_attn.in_proj_bias', 'transformer_encoder.layers.2.self_attn.in_proj_bias', 'transformer_encoder.layers.10.self_attn.in_proj_bias', 'transformer_encoder.layers.11.self_attn.in_proj_bias', 'transformer_encoder.layers.7.self_attn.in_proj_weight', 'transformer_encoder.layers.1.self_attn.in_proj_weight', 'transformer_encoder.layers.3.self_attn.in_proj_weight', 'transformer_encoder.layers.11.self_attn.in_proj_weight', 'cls_decoder._decoder.2.bias', 'transformer_encoder.layers.7.self_attn.in_proj_bias', 'cls_decoder._decoder.3.weight', 'transformer_encoder.layers.9.self_attn.in_proj_bias', 'cls_decoder._decoder.5.bias', 'cls_decoder.out_layer.bias', 'transformer_encoder.layers.9.self_attn.in_proj_weight', 'cls_decoder.out_layer.weight', 'transformer_encoder.layers.5.self_attn.in_proj_weight', 'transformer_encoder.layers.4.self_attn.in_proj_bias', 'transformer_encoder.layers.6.self_attn.in_proj_weight', 'transformer_encoder.layers.4.self_attn.in_proj_weight', 'cls_decoder._decoder.0.weight', 'cls_decoder._decoder.2.weight', 'transformer_encoder.layers.3.self_attn.in_proj_bias', 'cls_decoder._decoder.5.weight', 'transformer_encoder.layers.5.self_attn.in_proj_bias', 'transformer_encoder.layers.0.self_attn.in_proj_bias', 'transformer_encoder.layers.0.self_attn.in_proj_weight', 'transformer_encoder.layers.1.self_attn.in_proj_bias', 'cls_decoder._decoder.0.bias', 'transformer_encoder.layers.6.self_attn.in_proj_bias', 'transformer_encoder.layers.10.self_attn.in_proj_weight', 'transformer_encoder.layers.2.self_attn.in_proj_weight'}
Keys in the loaded state_dict but not in the current model state_dict:
{'transformer_encoder.layers.9.self_attn.Wqkv.weight', 'mvc_decoder.gene2query.bias', 'transformer_encoder.layers.3.self_attn.Wqkv.weight', 'transformer_encoder.layers.1.self_attn.Wqkv.weight', 'transformer_encoder.layers.3.self_attn.Wqkv.bias', 'transformer_encoder.layers.6.self_attn.Wqkv.bias', 'transformer_encoder.layers.4.self_attn.Wqkv.bias', 'transformer_encoder.layers.9.self_attn.Wqkv.bias', 'mvc_decoder.gene2query.weight', 'flag_encoder.weight', 'transformer_encoder.layers.1.self_attn.Wqkv.bias', 'transformer_encoder.layers.8.self_attn.Wqkv.bias', 'transformer_encoder.layers.2.self_attn.Wqkv.bias', 'transformer_encoder.layers.11.self_attn.Wqkv.weight', 'transformer_encoder.layers.10.self_attn.Wqkv.bias', 'transformer_encoder.layers.6.self_attn.Wqkv.weight', 'transformer_encoder.layers.7.self_attn.Wqkv.bias', 'transformer_encoder.layers.7.self_attn.Wqkv.weight', 'transformer_encoder.layers.2.self_attn.Wqkv.weight', 'transformer_encoder.layers.5.self_attn.Wqkv.bias', 'transformer_encoder.layers.8.self_attn.Wqkv.weight', 'transformer_encoder.layers.4.self_attn.Wqkv.weight', 'transformer_encoder.layers.5.self_attn.Wqkv.weight', 'transformer_encoder.layers.0.self_attn.Wqkv.weight', 'transformer_encoder.layers.11.self_attn.Wqkv.bias', 'transformer_encoder.layers.10.self_attn.Wqkv.weight', 'transformer_encoder.layers.0.self_attn.Wqkv.bias', 'mvc_decoder.W.weight'}
In sum, the model architecture does not match the checkpoint hosted on Google Drive.
Hi, thank you for your amazing work developing scGPT.
I am following the TutorialAnnotation.ipynb tutorial with the ms dataset from https://drive.google.com/drive/folders/1Qd42YNabzyr2pWt9xoY4cVMTAxsNBt4v and the model from https://drive.google.com/drive/folders/1oWh-ZRdhtoGQ2Fw24HP41FgLoomVo-y
The model-loading section quoted above is taken verbatim from the tutorial, and running it produces the state_dict mismatch described above.
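Comparing the two key lists, the mismatch looks systematic: the checkpoint uses fused `self_attn.Wqkv.{weight,bias}` projections (the flash-attention code path), while the locally built model uses PyTorch's vanilla attention naming, `self_attn.in_proj_{weight,bias}` — presumably because flash-attn is unavailable on mps, so `use_fast_transformer` falls back to the standard encoder. A hedged sketch of a key-remapping workaround (the function name is mine, and it assumes both layouts store the fused QKV projection with the same `(3*d_model, d_model)` shape; the other missing keys such as `cls_decoder.*` are freshly initialized heads and are expected to be absent from the checkpoint):

```python
def remap_fast_attn_keys(state_dict):
    """Rename fused-QKV keys from the flash-attn naming (`Wqkv`) to the
    vanilla nn.MultiheadAttention naming (`in_proj_*`), leaving all other
    keys untouched. Hypothetical helper, not part of scGPT."""
    remapped = {}
    for key, value in state_dict.items():
        new_key = (
            key.replace("self_attn.Wqkv.weight", "self_attn.in_proj_weight")
               .replace("self_attn.Wqkv.bias", "self_attn.in_proj_bias")
        )
        remapped[new_key] = value
    return remapped
```

This could be applied to `pretrained_dict` right after `torch.load(model_file)`, before the shape-filtering step, so the attention weights match by name and survive the filter.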