lavis-nlp / spert

PyTorch code for SpERT: Span-based Entity and Relation Transformer
MIT License

Help! Help! #53

Closed HuizhaoWang closed 3 years ago

HuizhaoWang commented 3 years ago

Hi, I am following the paper and your code and want to modify some parts, but I have run into a difficulty that has troubled me for a long time, so I am sincerely asking for your help. I have two different text-data sources, and both need to be embedded into vectors using a pretrained model. However, because the texts from the two sources have different features (different semantic information), they should be encoded with different pretrained models (such as "bert-base-cased" and "scibert-****"). While modifying the code, I found that only one of the pretrained models can actually be loaded; the weights of the other are newly initialized (a minimal sketch reproducing this follows the modified code below).

  1. The following is shown in the terminal:

```
Some weights of SpERT were not initialized from the model checkpoint at pretrained_models/scibert_scivocab_cased and are newly initialized: ['bert_addition.embeddings.word_embeddings.weight', 'bert_addition.embeddings.position_embeddings.weight', 'bert_addition.embeddings.token_type_embeddings.weight', 'bert_addition.embeddings.LayerNorm.weight', 'bert_addition.embeddings.LayerNorm.bias', 'bert_addition.encoder.layer.0.attention.self.query.weight', ...]
```

  2. The related modified parts (based on the code you shared) are as follows, in `spert_trainer.py`:

```python
def _load_model(self, input_reader):
    model_class = models.get_model(self._args.model_type)  # returns the SpERT class

    config = BertConfig.from_pretrained(self._args.model_path, cache_dir=self._args.cache_path)
    addition_config = BertConfig.from_pretrained(self._args.addition_model_path, cache_dir=self._args.cache_path)

    util.check_version(config, model_class, self._args.model_path)

    config.spert_version = model_class.VERSION
    model = model_class.from_pretrained(self._args.model_path,
                                        config=config,
                                        # SpERT model parameters
                                        addition_config=addition_config,
                                        cls_token=self._tokenizer.convert_tokens_to_ids('[CLS]'),
                                        relation_types=input_reader.relation_type_count - 1,
                                        # ...
                                        )
```

And in `models.py`:

```python
def __init__(self, config: BertConfig, addition_config: BertConfig,
             cls_token: int, relation_types: int, entity_types: int):
    super(SpERT, self).__init__(config)

    # BERT models
    self.bert = BertModel(config)
    self.bert_addition = BertModel(addition_config)
```
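For reference, here is a standalone sketch of what I believe is happening (the class name and checkpoint are placeholders, not SpERT code): `from_pretrained` only fills the weights whose names match the checkpoint, so the second encoder's parameters stay randomly initialized.

```python
from transformers import BertConfig, BertModel, BertPreTrainedModel

class TwoEncoderModel(BertPreTrainedModel):
    # hypothetical minimal model holding two BERT encoders
    def __init__(self, config: BertConfig):
        super().__init__(config)
        self.bert = BertModel(config)           # names match the checkpoint -> loaded
        self.bert_addition = BertModel(config)  # names missing from the checkpoint -> newly initialized
        self.init_weights()

# only 'bert.*' is loaded from the checkpoint; 'bert_addition.*' triggers
# the "newly initialized" warning shown above
model = TwoEncoderModel.from_pretrained('bert-base-cased')
```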

Could you provide code showing how to modify this? Thank you!

markus-eberts commented 3 years ago

Hi, this is not a SpERT issue/question, but rather related to the Transformers library we are using. Transformers has very good documentation; I'm sure you will find a way to adjust the model to your needs. In your case, you probably need to load the pretrained BERT models by calling `BertModel.from_pretrained` (inside the SpERT model) instead of using `model_class.from_pretrained` (in `spert_trainer.py`), like so:

```python
def __init__(self, config: BertConfig, addition_config: BertConfig,
             cls_token: int, relation_types: int, entity_types: int,
             size_embedding: int, max_pairs: int, prop_drop: float,
             model_name_or_path1: str, model_name_or_path2: str):
    super(SpERT, self).__init__(config)

    # layers
    self.rel_classifier = nn.Linear(config.hidden_size * 3 + size_embedding * 2, relation_types)
    self.entity_classifier = nn.Linear(config.hidden_size * 2 + size_embedding, entity_types)
    self.size_embeddings = nn.Embedding(100, size_embedding)
    self.dropout = nn.Dropout(prop_drop)

    # weight initialization before loading the BERT models, so BERT's parameters are not overwritten
    self.init_weights()

    self._cls_token = cls_token
    self._relation_types = relation_types
    self._entity_types = entity_types
    self._max_pairs = max_pairs

    # BERT models: load the pretrained weights here instead of via model_class.from_pretrained
    self.bert = BertModel.from_pretrained(model_name_or_path1)
    self.bert_addition = BertModel.from_pretrained(model_name_or_path2)
```
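Correspondingly, you would then construct the model directly in `spert_trainer.py` instead of going through `model_class.from_pretrained`. A rough sketch matching the signature above; the exact `self._args.*` names (`size_embedding`, `max_pairs`, `prop_drop`) are assumptions based on SpERT's command-line arguments:

```python
# inside _load_model: instantiate SpERT directly; the two BERT encoders
# load their pretrained weights in __init__ (see the snippet above)
config = BertConfig.from_pretrained(self._args.model_path, cache_dir=self._args.cache_path)
addition_config = BertConfig.from_pretrained(self._args.addition_model_path, cache_dir=self._args.cache_path)

model = model_class(config=config,
                    addition_config=addition_config,
                    cls_token=self._tokenizer.convert_tokens_to_ids('[CLS]'),
                    relation_types=input_reader.relation_type_count - 1,
                    entity_types=input_reader.entity_type_count,
                    size_embedding=self._args.size_embedding,
                    max_pairs=self._args.max_pairs,
                    prop_drop=self._args.prop_drop,
                    model_name_or_path1=self._args.model_path,
                    model_name_or_path2=self._args.addition_model_path)
```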