dhlee347 / pytorchic-bert

Pytorch Implementation of Google BERT
Apache License 2.0

Questions about loading the pretrained model #21

Open mingbocui opened 5 years ago

mingbocui commented 5 years ago
    def load(self, model_file, pretrain_file):
        """ load saved model or pretrained transformer (a part of model) """
        if model_file:
            print('Loading the model from', model_file)
            self.model.load_state_dict(torch.load(model_file))

        elif pretrain_file: # use pretrained transformer
            print('Loading the pretrained model from', pretrain_file)
            if pretrain_file.endswith('.ckpt'): # checkpoint file in tensorflow
                checkpoint.load_model(self.model.transformer, pretrain_file)
            elif pretrain_file.endswith('.pt'): # pretrain model file in pytorch
                self.model.transformer.load_state_dict(
                    {key[12:]: value
                        for key, value in torch.load(pretrain_file).items()
                        if key.startswith('transformer')}
                ) # load only transformer parts

Could I kindly ask what key[12:]: value means when you load a pretrained model? Is it meant to keep only the last layer? Thanks, hope for your reply.

dhlee347 commented 5 years ago

It is because I wanted to load only the transformer part of the saved model, not the whole model. The keys saved in the pretrain file start with 'transformer.', and key[12:] strips that 12-character prefix so the remaining names match the state dict of self.model.transformer itself.
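To make the slicing concrete, here is a minimal, self-contained sketch of the same dict comprehension run on a toy state dict (the key names and values here are illustrative, not taken from a real checkpoint):

    # Saved keys for the transformer sub-module carry a 'transformer.' prefix;
    # len('transformer.') == 12, which is where key[12:] comes from.
    full_state = {
        'transformer.embed.tok_embed.weight': 'tensor A',
        'transformer.blocks.0.attn.proj_q.weight': 'tensor B',
        'classifier.weight': 'tensor C',  # not part of the transformer, filtered out
    }
    transformer_state = {key[12:]: value
                         for key, value in full_state.items()
                         if key.startswith('transformer')}
    print(transformer_state)
    # {'embed.tok_embed.weight': 'tensor A', 'blocks.0.attn.proj_q.weight': 'tensor B'}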

mingbocui commented 4 years ago

@dhlee347 thanks for your reply. I have one more question: if I change the number of BERT layers from 12 to 6, should I also change key[12:] to key[6:]?
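For reference, the 12 in key[12:] is the length of the prefix string 'transformer.' on the saved keys, not the number of layers, so presumably it would stay the same for a 6-layer model; only the checkpoint being loaded would need to match the smaller architecture. A quick sanity check (the key name below is illustrative):

    # The slice width equals the prefix length, independent of model depth.
    prefix = 'transformer.'
    assert len(prefix) == 12
    key = 'transformer.blocks.0.attn.proj_q.weight'  # illustrative key name
    print(key[len(prefix):])  # -> blocks.0.attn.proj_q.weight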