utterworks / fast-bert

Super easy library for BERT-based NLP models
Apache License 2.0
1.86k stars, 341 forks

How to load model.safetensors with fast_bert (python3.*) #320

Open: pratikchhapolika opened this issue 6 months ago

pratikchhapolika commented 6 months ago

I trained a model with the latest transformers package (4.37.2), which saves the model as model.safetensors.
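For context, recent transformers releases such as 4.37.2 write model.safetensors by default when save_pretrained is called. If re-saving the model is an option, passing safe_serialization=False should produce the older pytorch_model.bin file instead. Below is a minimal sketch. It assumes a transformers 4.x release and uses a tiny, randomly initialised RoBERTa classifier as a stand-in for the fine-tuned model:

```python
import os
import tempfile

from transformers import RobertaConfig, RobertaForSequenceClassification

# Tiny, randomly initialised RoBERTa classifier so the sketch runs offline.
# In practice this would be the fine-tuned model.
config = RobertaConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=64,
    num_labels=2,
)
model = RobertaForSequenceClassification(config)

output_dir = tempfile.mkdtemp()
# safe_serialization=False asks transformers to write pytorch_model.bin
# (the pickle-based format) instead of model.safetensors.
model.save_pretrained(output_dir, safe_serialization=False)

print(sorted(os.listdir(output_dir)))
```

The directory written this way has the file name that the error message below asks for.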

I am loading the model in production with BertClassificationPredictor (from fast_bert.prediction import BertClassificationPredictor), but it raises this error:

OSError: Error no file named ['pytorch_model.bin', 'tf_model.h5', 'model.ckpt.index'] found in directory  or `from_tf` set to False 
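The error lists only pytorch_model.bin, tf_model.h5 and model.ckpt.index, not model.safetensors. This suggests that the transformers version fast_bert loads with in production does not look for safetensors files. As a quick check, the sketch below reports which of these files a model directory actually contains. find_weight_files and needs_conversion are hypothetical helpers, not part of fast_bert or transformers:

```python
from pathlib import Path

# File names listed in the OSError, plus the safetensors file that
# recent transformers versions write by default.
LEGACY_WEIGHT_FILES = ("pytorch_model.bin", "tf_model.h5", "model.ckpt.index")
SAFETENSORS_FILE = "model.safetensors"


def find_weight_files(model_dir):
    """Hypothetical helper: map each known weight file name to whether it exists."""
    path = Path(model_dir)
    return {
        name: (path / name).is_file()
        for name in LEGACY_WEIGHT_FILES + (SAFETENSORS_FILE,)
    }


def needs_conversion(model_dir):
    """Hypothetical helper: True when only model.safetensors is present,
    i.e. a loader that knows only the legacy names would raise the OSError."""
    found = find_weight_files(model_dir)
    has_legacy = any(found[name] for name in LEGACY_WEIGHT_FILES)
    return found[SAFETENSORS_FILE] and not has_legacy
```

For example, needs_conversion(self.model_path) returning True would point to a file-format mismatch rather than a wrong path.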

Does BertClassificationPredictor support loading model.safetensors? If so, how can I do this?

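from fast_bert.prediction import BertClassificationPredictor

# self.model_path and self.label_path are attributes of the surrounding class.
predictor = BertClassificationPredictor(model_path=self.model_path,
                                        label_path=self.label_path,
                                        multi_label=False,
                                        model_type='roberta',
                                        do_lower_case=True)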

@kaushaltrivedi