utterworks / fast-bert

Super easy library for BERT based NLP models
Apache License 2.0

[FastBert on AWS Sagemaker] container/bert/predictor.py does not accept batched input #125

Open azagniotov opened 4 years ago

azagniotov commented 4 years ago

Context

I followed a blog post that describes how to integrate FastBert with AWS SageMaker. I may be missing something trivial, but from what I can see, the container/bert/predictor.py file does not expose an option to send batched input to the endpoint.

TL;DR

I needed to tweak the code in container/bert/predictor.py to be able to POST batched input to an AWS SageMaker endpoint.

Details

There is an issue with the following line of code: https://github.com/kaushaltrivedi/fast-bert/blob/master/container/bert/predictor.py#L140

print("Invoked with text: {}.".format(text.encode("utf-8")))

which fails if text is a list or a dict, because neither type has an encode method. To bypass that, I changed the code to:

print("Invoked with text: {}.".format(str(text).encode("utf-8")))
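
To illustrate the failure and the workaround in isolation, here is a minimal sketch. The helper name describe_input is hypothetical and is not part of predictor.py; it only wraps the changed logging expression.

```python
def describe_input(text):
    """Build the log message for a str, list, or dict payload.

    Hypothetical helper: it wraps the changed expression from the issue.
    """
    # str() leaves a string unchanged and gives a printable form for
    # lists and dicts, so .encode() is always called on a str.
    return "Invoked with text: {}.".format(str(text).encode("utf-8"))


single = "this is my non-toxic comment"
batch = ["this is my non-toxic comment 1", "this is my non-toxic comment 2"]

# The original expression fails for a list: lists have no encode method.
try:
    batch.encode("utf-8")
    error_message = None
except AttributeError as exc:
    error_message = str(exc)

single_message = describe_input(single)
batch_message = describe_input(batch)
```

Note that the logged value is the repr of a bytes object (it starts with b), which matches what the original line printed for plain strings.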

I also increased the number of returned predictions, e.g.: result = json.dumps(predictions[:200])

Finally, as a quick and rough fix for my task, I changed the body of def predict(cls, text, bing_key=None): in the same file as follows:

        predictor_model = cls.get_predictor_model()
        if isinstance(text, list):
            # Batched input: predict all texts in one call (spell check is skipped).
            prediction = predictor_model.predict_batch(text)
        else:
            # Single input: optionally spell-check the text before predicting.
            if bing_key:
                spell_checker = BingSpellCheck(bing_key)
                text = spell_checker.spell_check(text)
            prediction = predictor_model.predict(text)

        return prediction

The above fix has worked for me, so now I am able to do inference on Sagemaker using a single input, e.g.:

{"text": "this is my non-toxic comment"}

or a batched input:

{"text": ["this is my non-toxic comment 1", "this is my non-toxic comment 2"]}
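
As a rough sketch of how the two request bodies above map onto the two branches, the snippet below parses each body and dispatches on the type of the "text" field. StubPredictor and handle_request are hypothetical names, and the label and score the stub returns are invented for illustration; only the predict and predict_batch method names come from the code in this issue.

```python
import json


class StubPredictor:
    """Hypothetical stand-in for the real predictor model.

    The returned label and score are made up for this example.
    """

    def predict(self, text):
        return [("non-toxic", 1.0)]

    def predict_batch(self, texts):
        # One prediction list per input text.
        return [self.predict(t) for t in texts]


def handle_request(body, model):
    """Parse a JSON request body and dispatch on the type of "text"."""
    text = json.loads(body)["text"]
    if isinstance(text, list):
        return model.predict_batch(text)
    return model.predict(text)


single_body = '{"text": "this is my non-toxic comment"}'
batch_body = (
    '{"text": ["this is my non-toxic comment 1", '
    '"this is my non-toxic comment 2"]}'
)

single_result = handle_request(single_body, StubPredictor())
batch_result = handle_request(batch_body, StubPredictor())
```

With this dispatch, a batched request returns one entry per input text, while a single request keeps the original response shape.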

I am happy to open a PR with a better version of the above code, unless I have approached this completely wrong and there is already a way to POST batched input to FastBert running in AWS SageMaker.

kaushaltrivedi commented 4 years ago

Hello. This will work, so it would be great if you could submit a PR. It would also be cool if you could update it to support batch transform on SageMaker, which would be useful for a large test dataset.