yzhangcs / parser

:rocket: State-of-the-art parsers for natural language.
https://parser.yzhang.site/

AutoModel #29

Closed: attardi closed this issue 4 years ago

attardi commented 4 years ago

What about adding the ability to use any model from HuggingFace, besides BERT?

It is enough to define a subclass of BertEmbedding like this:

from torch import nn
from transformers import AutoConfig, AutoModel

# BertEmbedding and ScalarMix come from this repo's own modules;
# the exact import paths may differ across versions.
from parser.modules.bert import BertEmbedding
from parser.modules.scalar_mix import ScalarMix


class AutoEmbedding(BertEmbedding):

    def __init__(self, model, n_layers, n_out, pad_index=0,
                 requires_grad=False):
        # Deliberately skip BertEmbedding.__init__, which would load a
        # BertModel; call nn.Module.__init__ and plug in AutoModel instead.
        super(BertEmbedding, self).__init__()

        # AutoConfig/AutoModel pick the right classes from the model name.
        config = AutoConfig.from_pretrained(model)
        config.output_hidden_states = True
        self.bert = AutoModel.from_pretrained(model, config=config)
        self.bert = self.bert.requires_grad_(requires_grad)
        self.n_layers = n_layers
        self.n_out = n_out
        self.pad_index = pad_index
        self.requires_grad = requires_grad
        self.hidden_size = self.bert.config.hidden_size

        # Weighted average over the last n_layers hidden states.
        self.scalar_mix = ScalarMix(n_layers)
        if self.hidden_size != n_out:
            self.projection = nn.Linear(self.hidden_size, n_out, False)
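With that in place, instantiation would work with any HuggingFace checkpoint name. A quick sketch (the checkpoint name and dimensions below are illustrative, not taken from this thread):

# Sketch: any Hub checkpoint name works; sizes are illustrative.
embed = AutoEmbedding('xlm-roberta-base', n_layers=4, n_out=100)
# The inherited BertEmbedding.forward then mixes the top 4 hidden
# states via ScalarMix and projects them to n_out dimensions.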
yzhangcs commented 4 years ago

You only need to change BertModel and BertTokenizer to the corresponding classes of the model you need, e.g., XLNet.
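For instance, the manual swap might look something like this (a minimal sketch using the stock xlnet-base-cased checkpoint):

from transformers import XLNetModel, XLNetTokenizer

# Manual swap: replace BertModel/BertTokenizer with the XLNet classes.
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
model = XLNetModel.from_pretrained('xlnet-base-cased',
                                   output_hidden_states=True)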

attardi commented 4 years ago

My suggestion will make the code more flexible, since you don't have to modify it to use other models: AutoTokenizer and AutoModel take care of choosing the right classes depending on the pretrained model name. You just need to set bert_model and it will load the right tokenizer and the right model. I am testing this with Finnish and Italian, and I am getting an improvement of about 2 LAS points with the language-specific models.
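To illustrate the class resolution (the Finnish and Italian checkpoint names below are example picks from the Hub, not ones named in this thread):

from transformers import AutoModel, AutoTokenizer

# The model name alone determines which classes get instantiated.
for name in ('TurkuNLP/bert-base-finnish-cased-v1',
             'dbmdz/bert-base-italian-cased'):
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModel.from_pretrained(name)
    print(name, '->', type(tokenizer).__name__, type(model).__name__)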

yzhangcs commented 4 years ago

Yeah, you are right. I'm going to release the code as a Python package soon, and I will add this feature then. Thanks for your suggestion.

attardi commented 4 years ago

The only other change needed is in CMD.__call__():

elif args.feat == 'bert':
    if args.bert_model.startswith('bert'):
        from transformers import BertTokenizer
        tokenizer = BertTokenizer.from_pretrained(args.bert_model)
    else:
        # BERT models from other authors on https://huggingface.co/models
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(args.bert_model)
    self.FEAT = SubwordField('bert',
                             pad=tokenizer.pad_token,
                             unk=tokenizer.unk_token,
                             bos=tokenizer.cls_token,
                             fix_len=args.fix_len,
                             tokenize=tokenizer.tokenize)
    self.FEAT.vocab = tokenizer.vocab

but possibly you can just get rid of the BertTokenizer case.
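Dropping that branch would collapse the snippet to something like this (a sketch; AutoTokenizer also resolves plain 'bert-*' checkpoint names to BertTokenizer):

elif args.feat == 'bert':
    # One branch suffices: AutoTokenizer handles 'bert-*' names too.
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.bert_model)
    # ... the SubwordField construction stays exactly as above ...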