kpe / bert-for-tf2

A Keras TensorFlow 2.0 implementation of BERT, ALBERT and adapter-BERT.
https://github.com/kpe/bert-for-tf2
MIT License

What would be a good way to pad input texts? #30

Closed hygkim95 closed 4 years ago

hygkim95 commented 4 years ago

Currently, I am just adding 0s to token_ids to match max_seq_len.

```python
import os

import numpy as np
import sentencepiece as spm
import bert


def tokenize_text(texts):
    model_name = "albert_base"
    max_seq_len = 64
    model_dir = bert.fetch_tfhub_albert_model(model_name, ".models")
    spm_model = os.path.join(model_dir, "assets", "30k-clean.model")
    sp = spm.SentencePieceProcessor()
    sp.load(spm_model)
    do_lower_case = True

    tokenized = []
    for text in texts:
        processed_text = bert.albert_tokenization.preprocess_text(text, lower=do_lower_case)
        token_ids = bert.albert_tokenization.encode_ids(sp, processed_text)
        # pad with 0s (the [PAD] id) up to max_seq_len, keeping the ids integer
        token_ids = np.append(token_ids, np.zeros(max_seq_len - len(token_ids), dtype=np.int32))
        tokenized.append(token_ids)
    return np.array(tokenized)
```

However, I found out that even the zero (padding) tokens are embedded to non-zero vectors. Is this something I have to worry about? If it is, what is the proper way of padding input texts?

muhammadfahid51 commented 4 years ago

Have you been able to make predictions? Can you please give me a flow of how to make predictions using any pre-trained model?

kpe commented 4 years ago

@hygkim95 - yes, non-zero activations for the padding tokens (i.e. the [PAD] token with an id of 0) are expected, unless you explicitly fine-tune/train for zero activations (which is neither necessary nor common). Just make sure to ignore the activations at the padding positions - if you use the token embeddings at all, which is usually not the case for sequence classification tasks like sentiment analysis.
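
For illustration, here is a minimal sketch of ignoring the padding activations, assuming token id 0 is the [PAD] id (as in the tokenizer snippet above) and `embeddings` is the `[batch, seq_len, hidden]` output of the BERT/ALBERT layer; the helper names are just for this example:

```python
import tensorflow as tf


def mask_padding(token_ids, embeddings):
    # token_ids:  [batch, seq_len] int tensor, 0 = [PAD]
    # embeddings: [batch, seq_len, hidden] activations from the model
    mask = tf.cast(tf.not_equal(token_ids, 0), embeddings.dtype)  # 1.0 for real tokens
    return embeddings * tf.expand_dims(mask, axis=-1)             # zero out padding positions


def masked_mean_pool(token_ids, embeddings):
    # mean over the non-padding positions only
    mask = tf.cast(tf.not_equal(token_ids, 0), embeddings.dtype)
    summed = tf.reduce_sum(embeddings * tf.expand_dims(mask, -1), axis=1)
    count = tf.maximum(tf.reduce_sum(mask, axis=1, keepdims=True), 1.0)
    return summed / count
```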

kpe commented 4 years ago

@muhammadfahid51 - the pre-trained models can either be used to generate representations/embeddings for a follow-up model, or be fine-tuned directly. In both cases, after training/fine-tuning you'd call model.predict() to get the model activations/logits.
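
As a rough sketch of that flow (loosely following the bert-for-tf2 README for the TF-Hub ALBERT model used above; the classification head and the number of classes are illustrative assumptions, not part of the original answer):

```python
import tensorflow as tf
from tensorflow import keras
import bert

model_name = "albert_base"
max_seq_len = 64

# fetch the pre-trained TF-Hub ALBERT model and create the Keras layer
model_dir = bert.fetch_tfhub_albert_model(model_name, ".models")
model_params = bert.albert_params(model_name)
l_bert = bert.BertModelLayer.from_params(model_params, name="albert")

# a simple classification head on top of the first ([CLS]) token - illustrative only
l_input_ids = keras.layers.Input(shape=(max_seq_len,), dtype="int32")
seq_output = l_bert(l_input_ids)                           # [batch, seq_len, hidden]
cls_output = keras.layers.Lambda(lambda seq: seq[:, 0, :])(seq_output)
logits = keras.layers.Dense(2)(cls_output)                 # e.g. 2 sentiment classes

model = keras.Model(inputs=l_input_ids, outputs=logits)
model.build(input_shape=(None, max_seq_len))

# load the pre-trained ALBERT weights into the layer
bert.load_albert_weights(l_bert, model_dir)

# compile and fine-tune as usual, then predict:
# model.compile(optimizer="adam",
#               loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
# model.fit(train_token_ids, train_labels, epochs=3)
# predictions = model.predict(token_ids)  # token_ids from tokenize_text() above
```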