Open sg85mang opened 5 years ago
It's an issue inside BertClassificationPredictor. It looks like it has something to do with the model_path param.
The error arises because the first parameter of BertClassificationPredictor requires the model path to be of type string. However, you have defined it to be of type PosixPath in MODEL_PATH = OUTPUT_DIR/'model_out' (to be more explicit, on POSIX systems such as Linux, Path() returns a PosixPath object, and joining with / keeps that type). This issue can be fixed simply by converting the path to a string before passing it in, i.e. model_path=str(MODEL_PATH).
Hi, I'm following the guide, and everything seems to work, except that when I create a predictor object I get:
File "bert.py", line 63, in <module>
do_lower_case=False)
File "/home/w3pt/.local/lib/python3.7/site-packages/fast_bert/prediction.py", line 38, in __init__
self.learner = self.get_learner()
File "/home/w3pt/.local/lib/python3.7/site-packages/fast_bert/prediction.py", line 45, in get_learner
self.model_path, do_lower_case=self.do_lower_case)
File "/home/w3pt/.local/lib/python3.7/site-packages/pytorch_transformers/tokenization_utils.py", line 293, in from_pretrained
return cls._from_pretrained(*inputs, **kwargs)
File "/home/w3pt/.local/lib/python3.7/site-packages/pytorch_transformers/tokenization_utils.py", line 421, in _from_pretrained
tokenizer = cls(*init_inputs, **init_kwargs)
File "/home/w3pt/.local/lib/python3.7/site-packages/pytorch_transformers/tokenization_xlnet.py", line 90, in __init__
self.sp_model.Load(vocab_file)
File "/home/w3pt/.local/lib/python3.7/site-packages/sentencepiece.py", line 118, in Load
return _sentencepiece.SentencePieceProcessor_Load(self, filename)
TypeError: not a string
The code is below: from langdetect import detect from goose3 import Goose import re from fast_bert.data_cls import BertDataBunch from pathlib import Path
from fast_bert.learner_cls import BertLearner from fast_bert.metrics import accuracy import logging import torch from fast_bert.prediction import BertClassificationPredictor
logger = logging.getLogger() device_cuda = torch.device("cuda") metrics = [{'name': 'accuracy', 'function': accuracy}]
DATA_PATH = Path('./') LABEL_PATH = Path('./') OUTPUT_DIR = Path('./')
MODEL_PATH = OUTPUT_DIR/'model_out'
databunch = BertDataBunch(DATA_PATH, LABEL_PATH, tokenizer='bert-base-uncased', train_file='train.csv', val_file='val.csv', label_file='labels.csv', text_col='text', label_col='label', batch_size_per_gpu=16, max_seq_length=512, multi_gpu=True, multi_label=False, model_type='bert')
learner = BertLearner.from_pretrained_model( databunch, pretrained_path='bert-base-uncased', metrics=metrics, device=device_cuda, logger=logger, output_dir=OUTPUT_DIR, finetuned_wgts_path=None, warmup_steps=500, multi_gpu=True, is_fp16=True, multi_label=False, logging_steps=50) learner.fit(epochs=6, lr=6e-5, validate=True, # Evaluate the model after each epoch schedule_type="warmup_cosine", optimizer_type="lamb")
learner.save_model() predictor = BertClassificationPredictor( model_path=MODEL_PATH, label_path=LABEL_PATH, # location for labels.csv file multi_label=False, model_type='xlnet', do_lower_case=False)
while True: th = [] gurl = input("enter site:") if not 'http' in gurl: gurl = "http://" + gurl g = Goose({'http_timeout': 60.0}) try: article = g.extract(url=gurl) except Exception as e: print ("goose failed. - %s" % e) continue title = article.title meta_desc = article.meta_description content = title + "." + meta_desc + "." + article.cleaned_text content = content.replace(",","").replace("\n","").replace("\r","") try: lang = detect(content) except: continue if lang != "en": print ("skipping non english") continue th.append(content) print(th) predictions = predictor.predict(content) print (predictions)