flairNLP / flair

A very simple framework for state-of-the-art Natural Language Processing (NLP)
https://flairnlp.github.io/flair/

Can you use Flair models on Android? #1423

Closed juni-vogt closed 3 years ago

juni-vogt commented 4 years ago

I trained a binary text classification model with Flair using GloVe embeddings, as in the tutorial from the repo, which was fun and worked very nicely. Now I would like to use it in an Android app. Is this possible right now?

Since Flair is built on top of PyTorch, I suspect that not everything that works with plain PyTorch models also works with Flair models. There are, however, several ways to use PyTorch models on Android, so maybe my question is how to make these work with Flair models as well. For example, PyTorch Mobile can be used by first converting a PyTorch model to a TorchScript model and then loading it with the PyTorch Mobile Android library. I tried this with the following code, but I got an error.

Code and Error

Model training

```python
from pathlib import Path

from flair.data import Corpus
from flair.data_fetcher import NLPTaskDataFetcher
from flair.embeddings import WordEmbeddings, FlairEmbeddings, DocumentRNNEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

embeddings = "glove"

# 1. get the corpus
# corpus: Corpus = TREC_6()

# read the dataset description once (a second read() would return an empty string)
with open("../flair_grooming/modified_dataset/datasetinfo.txt", "r") as datasetinfo:
    dataset_name = datasetinfo.read()

print("################################")
print("Using dataset `%s'" % dataset_name)
print("Using `%s' embeddings" % embeddings)
print("################################")

corpus = NLPTaskDataFetcher.load_classification_corpus(
    Path('../flair_grooming/modified_dataset/'),
    test_file='test.csv',
    dev_file='dev.csv',
    train_file='train.csv'
)

# 2. create the label dictionary
label_dict = corpus.make_label_dictionary()  # ["predator", "non-predator"]
# print(label_dict)

# 3. make a list of word embeddings
word_embeddings = [
    WordEmbeddings('glove'),

    # flair embeddings for state-of-the-art results
    # FlairEmbeddings('news-forward-fast'),
    # FlairEmbeddings('news-backward-fast'),
]

# 4. initialize document embedding by passing list of word embeddings
# Can choose between many RNN types (GRU by default, to change use rnn_type parameter)
document_embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
    word_embeddings,
    hidden_size=512,
    reproject_words=True,
    reproject_words_dimension=256,
)

# from flair.embeddings import WordEmbeddings, FlairEmbeddings, DocumentLSTMEmbeddings
# document_embeddings = DocumentLSTMEmbeddings(
#     word_embeddings,
#     hidden_size=512,
#     reproject_words=True,
#     reproject_words_dimension=256
# )

# 5. create the text classifier
classifier = TextClassifier(
    document_embeddings,
    label_dictionary=label_dict,
    beta=.5,  # beta for F-Score, we use a F_{0.5}-score
    multi_label=False,  # binary classification
)

# 6. initialize the text classifier trainer
trainer = ModelTrainer(
    classifier,
    corpus
)

# 7. start the training
trainer.train(
    'networks/pan12-%s' % embeddings,
    learning_rate=0.1,
    mini_batch_size=32,
    anneal_factor=0.5,
    patience=5,
    max_epochs=150,
    embeddings_storage_mode='gpu',
    checkpoint=True
)

# 8. plot weight traces (optional)
from flair.visual.training_curves import Plotter
plotter = Plotter()
plotter.plot_weights('networks/pan12-%s/weights.txt' % embeddings)
```

Model conversion to TorchScript

```python
import torch
from flair.models import TextClassifier

embedding = "glove"

print("converting model with `%s' embeddings" % embedding)

classifier = TextClassifier.load('networks/pan12-%s/best-model.pt' % embedding)

# torch.jit.script expects a torch.nn.Module
# TextClassifier is a child of flair.nn.Model, which is a child of torch.nn.Module
traced_script_module = torch.jit.script(classifier)
print("converted to script-module")

traced_script_module.save("script-module-model.pt")
print("saved")
```

Conversion error

```
converting model with `glove' embeddings
2020-02-09 18:54:56,127 loading file networks/pan12-glove/best-model.pt
Traceback (most recent call last):
  File "to_mobile.py", line 24, in <module>
    traced_script_module = torch.jit.script(classifier)
  File "/home/vogtmatt/.local/lib/python3.6/site-packages/torch/jit/__init__.py", line 1162, in script
    return _convert_to_script_module(obj)
  File "/home/vogtmatt/.local/lib/python3.6/site-packages/torch/jit/__init__.py", line 1812, in _convert_to_script_module
    return WeakScriptModuleProxy(mod, stubs)
  File "/home/vogtmatt/.local/lib/python3.6/site-packages/torch/jit/__init__.py", line 1386, in init_then_register
    original_init(self, *args, **kwargs)
  File "/home/vogtmatt/.local/lib/python3.6/site-packages/torch/jit/__init__.py", line 1736, in __init__
    _create_methods_from_stubs(self, stubs)
  File "/home/vogtmatt/.local/lib/python3.6/site-packages/torch/jit/__init__.py", line 1347, in _create_methods_from_stubs
    self._c._create_methods(self, defs, rcbs, defaults)
  File "/home/vogtmatt/.local/lib/python3.6/site-packages/torch/jit/__init__.py", line 982, in _make_strong_submodule
    new_strong_submodule = _convert_to_script_module(module)
  File "/home/vogtmatt/.local/lib/python3.6/site-packages/torch/jit/__init__.py", line 1796, in _convert_to_script_module
    raise RuntimeError("No forward method was defined on {}".format(mod))
RuntimeError: No forward method was defined on DocumentRNNEmbeddings(
  (embeddings): StackedEmbeddings(
    (list_embedding_0): WordEmbeddings('glove')
  )
  (word_reprojection_map): Linear(in_features=100, out_features=256, bias=True)
  (rnn): GRU(256, 512, batch_first=True)
  (dropout): Dropout(p=0.5, inplace=False)
)
```

Is it possible to convert Flair Models to TorchScript right now? It might also be possible to use ONNX to convert Flair's PyTorch models to TensorFlow models, which could then be used on Android with TensorFlow Lite, but I haven't tried this yet. Thanks for your help!
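
For context on the error: `torch.jit.script` compiles the `forward` method of a module and of each of its submodules, and the traceback shows it stopping at `DocumentRNNEmbeddings`, which does not define one. The sketch below is not a Flair conversion. It is a hypothetical, tensor-only stand-in (the class name `TensorOnlyClassifier`, the `decoder` layer, and the label count are assumptions) that mirrors the layer sizes printed in the error message, to show the kind of module that TorchScript does accept. It assumes the GloVe lookup happens outside the module and that trained weights would have to be copied in separately.

```python
import io

import torch
import torch.nn as nn


class TensorOnlyClassifier(nn.Module):
    """Hypothetical stand-in: same layer shapes as in the error message,
    but with a forward() that works purely on tensors."""

    def __init__(self, embedding_dim: int = 100, reproject_dim: int = 256,
                 hidden_size: int = 512, num_labels: int = 2):
        super().__init__()
        self.word_reprojection_map = nn.Linear(embedding_dim, reproject_dim)
        self.rnn = nn.GRU(reproject_dim, hidden_size, batch_first=True)
        # Assumed final layer mapping the document vector to label scores.
        self.decoder = nn.Linear(hidden_size, num_labels)

    def forward(self, word_vectors: torch.Tensor) -> torch.Tensor:
        # word_vectors: (batch, tokens, embedding_dim), looked up by the caller.
        projected = self.word_reprojection_map(word_vectors)
        _, last_hidden = self.rnn(projected)
        # last_hidden: (num_layers, batch, hidden_size); use the last layer.
        return self.decoder(last_hidden[-1])


model = TensorOnlyClassifier().eval()

# Scripting succeeds here because every submodule defines forward().
scripted = torch.jit.script(model)

# Serialize to an in-memory buffer and load it back; scripted.save(path)
# would write the same data to a file instead.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)

example = torch.randn(1, 7, 100)  # one sentence of 7 tokens, 100-dim vectors
with torch.no_grad():
    scores = scripted(example)
    reloaded_scores = reloaded(example)
print(scores.shape)
```

With this split, the app would be responsible for tokenizing the text and mapping tokens to GloVe vectors before calling the scripted module, since that part of Flair operates on `Sentence` objects rather than tensors.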

alanakbik commented 4 years ago

@kashif any ideas how this could be done?

svmihar commented 4 years ago

Any updates? I got the same error.

stale[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.