Context

I followed a blog post that described how to glue FastBert and AWS SageMaker. I may be missing something trivial, but from what I can see, the container/bert/predictor.py file does not expose an option to provide a batched input to the endpoint.
TL;DR
I needed to tweak the code in container/bert/predictor.py to be able to POST a batched input to the AWS SageMaker endpoint.
print("Invoked with text: {}.".format(text.encode("utf-8")))
which would fail if text is of type list or a dict. To bypass that, I changed the code to:
print("Invoked with text: {}.".format(str(text).encode("utf-8")))
I also increased the number of returned predictions, e.g.:
result = json.dumps(predictions[:200])
Also, as a quick and rough fix for my task, I changed def predict(cls, text, bing_key=None): in the same file as follows:
predictor_model = cls.get_predictor_model()
prediction = []
if isinstance(text, list):
    # batched input: run all texts through the model at once
    prediction = predictor_model.predict_batch(text)
else:
    # single input: optionally spell-check before predicting
    if bing_key:
        spellChecker = BingSpellCheck(bing_key)
        text = spellChecker.spell_check(text)
    prediction = predictor_model.predict(text)
return prediction
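To make the flow concrete, here is a minimal sketch of how the container's /invocations handler could dispatch to the method above. It assumes the Flask-based serving setup from the blog post and a scoring class like the one in predictor.py; the names here (app, ScoringService) are illustrative rather than the repo's exact code.

import json

import flask

from predictor import ScoringService  # hypothetical import; the scoring class lives in predictor.py

app = flask.Flask(__name__)

@app.route("/invocations", methods=["POST"])
def transformation():
    # Parse the JSON body; "text" may be a single string or a list of strings.
    data = flask.request.get_json(force=True)
    text = data["text"]

    # With the fix above, predict() handles both single and batched inputs.
    predictions = ScoringService.predict(text)

    result = json.dumps(predictions[:200])
    return flask.Response(response=result, status=200, mimetype="application/json")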
The above fix has worked for me: I am now able to do inference on SageMaker using a single input, e.g.:
{"text": "this is my non-toxic comment"}
or a batched input:
{"text": ["this is my non-toxic comment 1", "this is my non-toxic comment 2"]}
I am happy to issue a PR with a better version of the above code, unless I approached this blatantly wrong and there is already a way to POST a batched input to FastBert running on AWS SageMaker.
Hello. This will work, so it would be great if you could submit a PR. It would also be cool if you could update it to support batch transform on SageMaker; that would be useful for a large test dataset.
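For reference, a batch transform job over a test dataset in S3 could be created roughly like this (a minimal boto3 sketch; the job name, model name, S3 paths, and instance type are all placeholders):

import boto3

sm = boto3.client("sagemaker")

sm.create_transform_job(
    TransformJobName="fast-bert-batch-inference",  # placeholder
    ModelName="fast-bert-model",  # placeholder: a model already registered in SageMaker
    TransformInput={
        "DataSource": {
            "S3DataSource": {
                "S3DataType": "S3Prefix",
                "S3Uri": "s3://my-bucket/test-dataset/",  # placeholder input prefix
            }
        },
        "ContentType": "application/json",
        "SplitType": "Line",  # one JSON record per line
    },
    TransformOutput={"S3OutputPath": "s3://my-bucket/predictions/"},  # placeholder output path
    TransformResources={"InstanceType": "ml.m5.xlarge", "InstanceCount": 1},
)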