alexriggio / BERT-LoRA-TensorRT

This repository contains a custom implementation of the BERT model, fine-tuned for specific tasks, along with an implementation of Low-Rank Adaptation (LoRA). The models are optimized for high performance using NVIDIA's TensorRT.
Apache License 2.0

AssertionError: mismatched keys #2

Closed · martijnsiepel01 closed this 8 months ago

martijnsiepel01 commented 8 months ago

I tried to run the provided fine_tuning notebook. However, when I try to fine-tune, I get the following error:

Loading weights from pretrained model: bert-base-uncased
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

AssertionError                            Traceback (most recent call last)
in <cell line: 3>()
      1 # load tokenizer and pretrained model
      2 tokenizer_base = BertTokenizer.from_pretrained('bert-base-uncased')
----> 3 bert_base = MyBertForSequenceClassification.from_pretrained(
      4     model_type='bert-base-uncased',
      5     config_args={"vocab_size": 30522, "n_classes": 2}  # these are default configs but just added for explicity

/content/bert_from_scratch.py in from_pretrained(cls, model_type, config_args, adaptive_weight_copy)
    300
    301     # Check that all keys match between the state dictionary of the custom and pretrained model
--> 302     assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
    303
    304     # Replace weights in the custom model with the weights from the pretrained model

AssertionError: mismatched keys: 201 != 202

What could cause this?

alexriggio commented 8 months ago

It sounds like Hugging Face updated their model, so the pretrained state dictionary now has one fewer key than this repo's custom model expects (201 vs. 202 in your traceback).
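If you want to double-check what your installed version produces, a minimal sketch along these lines prints the Hugging Face key count directly (the `position_ids` buffer named below is only a likely suspect, since some newer transformers releases no longer include it in the state dict):

```python
# Diagnostic sketch: count the Hugging Face state-dict keys produced by the
# installed transformers version and check whether the position_ids buffer is
# included (its presence varies between transformers releases).
from transformers import BertForSequenceClassification

hf_model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
hf_keys = list(hf_model.state_dict().keys())

print(f"Hugging Face state-dict keys: {len(hf_keys)}")  # e.g. 201 vs. 202 depending on version
print("position_ids included:", any(k.endswith("position_ids") for k in hf_keys))
```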

Try using transformers==4.28.1.
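For example, you can apply the pin at the top of the fine_tuning notebook and then restart the runtime so the downgraded package is actually picked up:

```python
# Pin transformers to the version the custom from_pretrained() was written against.
# Restart the notebook runtime after installing so the pinned version is imported.
!pip install transformers==4.28.1
```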

You can find mention of this issue in #1.

Hope that helps and let me know if you encounter further issues.

martijnsiepel01 commented 8 months ago

This fixed it, thanks!