amaiya / ktrain

ktrain is a Python library that makes deep learning and AI more accessible and easier to apply
Apache License 2.0

DistilBert / XLM #27

Closed kinoute closed 4 years ago

kinoute commented 4 years ago

Hi,

Thanks for this great library. Coming from the FastAI course, I find it very familiar and easy to use! Great job. I was wondering if you could add some new models, especially DistilBERT from Hugging Face, which according to its paper is lighter and 60% faster to train than BERT while keeping 97% of BERT's performance.

XLM from Facebook also seems to be a great cross-lingual model: https://github.com/facebookresearch/XLM#ii-cross-lingual-language-model-pretraining-xlm

Thank you!

amaiya commented 4 years ago

@kinoute Yes, the plan is to integrate HuggingFace models into ktrain eventually. I'd like to add ALBERT, too. Thanks for your comment.

amaiya commented 4 years ago

As of v0.8.x, ktrain includes a thin wrapper to the Hugging Face transformers library for text classification:

import ktrain
from ktrain import text

# x_train, y_train, x_test, y_test, and train_b (which supplies the class names)
# are assumed to come from your own dataset, e.g., scikit-learn's 20 Newsgroups data.
MODEL_NAME = 'distilbert-base-uncased'
t = text.Transformer(MODEL_NAME, maxlen=500, classes=train_b.target_names)
trn = t.preprocess_train(x_train, y_train)
val = t.preprocess_test(x_test, y_test)
model = t.get_classifier()
learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=6)
learner.fit_onecycle(3e-5, 4)  # 1cycle policy: 4 epochs with a max learning rate of 3e-5

TF versions of some of the models seem to produce errors in v2.3.0 of transformers. Patches for some of these problems exist in open PRs on Hugging Face's GitHub but have not yet been released; others remain unresolved. The distilbert-base-uncased model does work, though, and trains in about half the time of BERT. See the tutorial notebook for usage instructions.
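Once training finishes, the model can be wrapped in a predictor for inference on raw text. Below is a minimal sketch continuing from the snippet above; the example sentence and save path are just placeholders:

predictor = ktrain.get_predictor(learner.model, preproc=t)
predictor.predict('I need a new graphics card for my gaming PC.')  # returns the predicted class label
predictor.save('/tmp/my_distilbert_predictor')  # persist the model and preprocessor for later reloading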

The English distilbert model can also be accessed using the conventional API, as shown here.
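For reference, that conventional route looks roughly like the sketch below. It assumes, based on the ktrain docs rather than this thread, that texts_from_array accepts preprocess_mode='distilbert' and that text_classifier accepts the 'distilbert' model name; the data variables are the same placeholders as above.

import ktrain
from ktrain import text

# Preprocess raw text arrays for DistilBERT; returns train/validation sets plus a preprocessor.
trn, val, preproc = text.texts_from_array(x_train=x_train, y_train=y_train,
                                          x_test=x_test, y_test=y_test,
                                          class_names=train_b.target_names,
                                          preprocess_mode='distilbert',
                                          maxlen=500)
model = text.text_classifier('distilbert', train_data=trn, preproc=preproc)
learner = ktrain.get_learner(model, train_data=trn, val_data=val, batch_size=6)
learner.fit_onecycle(3e-5, 4)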