utterworks / fast-bert

Super easy library for BERT based NLP models
Apache License 2.0

Error in `predict` for BertClassificationPredictor - logger is None #305

Open · villalbamartin opened this issue 2 years ago

villalbamartin commented 2 years ago

I believe I have found a bug in the `predict` method of the BertClassificationPredictor class.

I am loading a trained model like this:

from fast_bert.prediction import BertClassificationPredictor

self.model = BertClassificationPredictor(model_path=model_dir,
                                         label_path=labels_dir,
                                         multi_label=True,
                                         model_type='distilbert',
                                         do_lower_case=False,
                                         device=None)

and then generating predictions like this:

pred = self.model.predict(text)

Looking at the source code, the BertClassificationPredictor constructor in the first snippet calls self.get_learner(), which in turn calls BertLearner.from_pretrained_model with logger=None.

Calling predict then eventually reaches BertLearner.predict_batch, which starts like this:

    def predict_batch(self, texts=None):
        if texts:
            self.logger.info("---PROGRESS-STATUS---: Tokenizing input texts...")
            dl = self.data.get_dl_from_texts(texts)
            self.logger.info("---PROGRESS-STATUS---: Tokenizing input texts...DONE")
        elif self.data.test_dl:
            dl = self.data.test_dl
        else:
            dl = self.data.val_dl

Because self.logger is None, the first self.logger.info(...) call fails with the following exception (paths redacted):

  File "/.../bert_model.py", line 192, in classify
    pred = self.model.predict(text)
  File "/.../fast_bert/prediction.py", line 79, in predict
    predictions = self.predict_batch([text])[0]
  File "/.../fast_bert/prediction.py", line 76, in predict_batch
    return self.learner.predict_batch(texts)
  File "/.../fast_bert/learner_cls.py", line 553, in predict_batch
    self.logger.info("---PROGRESS-STATUS---: Tokenizing input texts...")
AttributeError: 'NoneType' object has no attribute 'info'
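The failure mode can be reproduced without fast-bert at all. The sketch below uses a hypothetical StubLearner class (not part of fast-bert) that stores whatever logger it is given, including None, and then calls .info() on it the way predict_batch does. It is a simplified illustration of the pattern, not the library's code.

```python
# Hypothetical stand-in that mimics the pattern described above;
# StubLearner is NOT a fast-bert class.
class StubLearner:
    def __init__(self, logger=None):
        # Mirrors get_learner() passing logger=None through unchanged.
        self.logger = logger

    def predict_batch(self, texts=None):
        if texts:
            # With self.logger set to None, this attribute lookup raises
            # AttributeError before any tokenization happens.
            self.logger.info("---PROGRESS-STATUS---: Tokenizing input texts...")
        return []


def reproduce():
    """Return the AttributeError message, or None if no error occurs."""
    learner = StubLearner(logger=None)
    try:
        learner.predict_batch(["some text"])
    except AttributeError as exc:
        return str(exc)
    return None


print(reproduce())  # 'NoneType' object has no attribute 'info'
```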

As a workaround, I currently set a logger manually after constructing the predictor:

import logging

self.model = BertClassificationPredictor(...)
self.model.learner.logger = logging.getLogger()
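If this is indeed a bug, one possible library-side fix is to never store None as the logger and instead fall back to a module-level logger at construction time. The sketch below shows that defensive default on a hypothetical SafeLearner class; it is an assumption about how a fix could look, not fast-bert's actual code or an official patch.

```python
import logging


# Hypothetical sketch of a defensive default; SafeLearner is not a
# fast-bert class, and the real library may resolve this differently.
class SafeLearner:
    def __init__(self, logger=None):
        # Fall back to a module-level logger instead of storing None, so
        # every later self.logger.info(...) call is valid.
        self.logger = logger if logger is not None else logging.getLogger(__name__)

    def predict_batch(self, texts=None):
        if texts:
            self.logger.info("---PROGRESS-STATUS---: Tokenizing input texts...")
            # Placeholder result; real code would tokenize and run the model.
            return [[] for _ in texts]
        return []


learner = SafeLearner(logger=None)
print(isinstance(learner.logger, logging.Logger))  # True
print(learner.predict_batch(["some text"]))  # [[]]
```

Defaulting once in the constructor keeps the existing self.logger.info(...) call sites unchanged; the alternative of guarding each call with `if self.logger:` works too but has to be repeated everywhere the logger is used.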

I am wondering, though, whether this is truly a bug or whether I am skipping an important step somewhere.