utterworks / fast-bert

Super easy library for BERT based NLP models
Apache License 2.0

Creation of BertClassificationPredictor breaks #319

Open pbouss opened 9 months ago

pbouss commented 9 months ago

The BertClassificationPredictor class cannot be instantiated. It creates a databunch without a train file, so no train_dl is available, but BertLearner.from_pretrained_model later needs the length of that train_dl. An example error message is:

bert_model.load(Path)

  File "....", line ..., in load
    self.model = BertClassificationPredictor(model_path=model_dir,
  File ".../miniconda3/envs/bert/lib/python3.8/site-packages/fast_bert/prediction.py", line 48, in __init__
    self.learner = self.get_learner()
  File ".../miniconda3/envs/bert/lib/python3.8/site-packages/fast_bert/prediction.py", line 68, in get_learner
    learner = BertLearner.from_pretrained_model(
  File ".../miniconda3/envs/bert/lib/python3.8/site-packages/fast_bert/learner_cls.py", line 246, in from_pretrained_model
    return BertLearner(
  File ".../miniconda3/envs/bert/lib/python3.8/site-packages/fast_bert/learner_cls.py", line 346, in __init__
    t_total = len(train_dataloader) // self.grad_accumulation_steps * epochs
TypeError: object of type 'NoneType' has no len()

This could be fixed by adding a flag here in learner_cls.py (ll. 339-348):

    if self.max_steps > 0:
        t_total = self.max_steps
        self.epochs = (
            self.max_steps // len(train_dataloader) // self.grad_accumulation_steps
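
One possible shape for such a guard, as a minimal sketch rather than the library's actual code: the helper name total_training_steps and its signature are hypothetical and do not exist in fast-bert. It assumes that a missing train dataloader means the learner is only used for prediction, so no training steps need to be scheduled.

```python
from typing import Optional, Sized


def total_training_steps(
    train_dataloader: Optional[Sized],
    epochs: int,
    grad_accumulation_steps: int = 1,
    max_steps: int = -1,
) -> int:
    """Hypothetical helper (not part of fast-bert): compute t_total safely."""
    if train_dataloader is None:
        # Assumption: no train dataloader means prediction-only use,
        # so there is nothing to schedule and len() must not be called.
        return 0
    if max_steps > 0:
        # A fixed step budget takes precedence over the epoch count.
        return max_steps
    # Same expression as the line that fails in the traceback.
    return len(train_dataloader) // grad_accumulation_steps * epochs


# A list stands in for a dataloader with 10 batches.
print(total_training_steps(None, epochs=3))
print(total_training_steps([0] * 10, epochs=3, grad_accumulation_steps=2))
```

With a check like this, the optimizer and scheduler setup could be skipped whenever the result is 0, which would let BertClassificationPredictor build its learner without a train file. Whether that matches the maintainers' intended design is an open question.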