cdqa-suite / cdQA

⛔ [NOT MAINTAINED] An End-To-End Closed Domain Question Answering System.
https://cdqa-suite.github.io/cdQA-website/
Apache License 2.0

NotFittedError: BM25Vectorizer - Vocabulary wasn't fitted. #328

Open fin-amal-joseph opened 4 years ago

fin-amal-joseph commented 4 years ago

I've got an issue when predicting. Sample code and the error are below.

cdqa_pipeline = QAPipeline(reader='bert_models/bert_qa.joblib')
cdqa_pipeline.fit_reader('bert_models/SQuAD_1.1/train-v1.1.json')
cdqa_pipeline = QAPipeline(reader='bert_out.joblib')
cdqa_pipeline.predict(query="Who is chaplin?")


NotFittedError                            Traceback (most recent call last)

in ()
----> 1 cdqa_pipeline.predict(query="Who is chaplin?")

/usr/local/lib/python3.6/dist-packages/sklearn/utils/validation.py in check_is_fitted(estimator, attributes, msg, all_or_any)
    912
    913     if not all_or_any([hasattr(estimator, attr) for attr in attributes]):
--> 914         raise NotFittedError(msg % {'name': type(estimator).__name__})
    915
    916

NotFittedError: BM25Vectorizer - Vocabulary wasn't fitted.

tianpaul01 commented 4 years ago

I run into the same problem when I load my new dataset. Has this been fixed?

fmikaelian commented 4 years ago

You have to fit the pipeline retriever to the dataframe with the documents before calling the predict() function:

cdqa_pipeline.fit_retriever(df=df)
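
To illustrate why the error appears, here is a minimal, self-contained sketch of the fit-before-predict check. It does not use cdQA or scikit-learn: `ToyRetriever`, its word-overlap scoring, and the local `NotFittedError` class are hypothetical stand-ins, assumed only to mirror the pattern visible in the traceback (a check for fitted attributes before predicting).

```python
import re


class NotFittedError(Exception):
    """Local stand-in for scikit-learn's NotFittedError (illustration only)."""


def tokenize(text):
    # Lowercase and keep word characters only, so "Chaplin?" matches "chaplin".
    return set(re.findall(r"\w+", text.lower()))


class ToyRetriever:
    """Hypothetical retriever; not cdQA's BM25 implementation."""

    def fit(self, documents):
        # Fitting creates the attributes that mark the object as fitted.
        self.documents_ = list(documents)
        self.vocabulary_ = sorted(set().union(*(tokenize(d) for d in documents)))
        return self

    def predict(self, query):
        # Same idea as check_is_fitted: refuse to predict without fitted state.
        if not hasattr(self, "vocabulary_"):
            raise NotFittedError("ToyRetriever - Vocabulary wasn't fitted.")
        terms = tokenize(query)
        # Score each document by the number of words it shares with the query.
        scores = [len(terms & tokenize(doc)) for doc in self.documents_]
        return self.documents_[scores.index(max(scores))]


retriever = ToyRetriever()

# Predicting before fitting raises the error.
try:
    retriever.predict("Who is Chaplin?")
    first_error = None
except NotFittedError as err:
    first_error = str(err)

# After fitting on the documents, the same call succeeds.
retriever.fit(["Chaplin was a comic actor", "BM25 ranks documents by term frequency"])
best = retriever.predict("Who is Chaplin?")
```

In the toy version, as in the reported case, nothing is wrong with the query itself; the retriever simply has no vocabulary until it has seen the documents.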

sftblw commented 4 years ago

I had a similar problem because I thought cdqa_pipeline.dump_reader() would dump the trained retriever too. It only saves the reader, so the retriever has to be fitted again after the reader is loaded.

File A

from cdqa.pipeline import QAPipeline

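# df is the dataframe of documents, assumed to be defined earlier
cdqa_pipeline = QAPipeline(reader='./models/distilbert_qa.joblib')
cdqa_pipeline.fit_retriever(df=df)

# saves the reader only; the fitted retriever is not included
cdqa_pipeline.dump_reader('./models/distilbert_qa_fine_tuned.joblib')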

File B

from cdqa.pipeline import QAPipeline

cdqa_pipeline = QAPipeline(reader='./models/distilbert_qa_fine_tuned.joblib')
cdqa_pipeline.fit_retriever(df=df)

# fix:
# cdqa_pipeline = QAPipeline(reader='./models/distilbert_qa.joblib')
# cdqa_pipeline = cdqa_pipeline.fit_retriever(df=df)

while True:
    q = input('Question: ')
    answer = cdqa_pipeline.predict(query=q)
    print(answer)
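
The File A / File B situation can be sketched without cdQA. The example below is a toy model, assuming only what is described above: saving the reader does not save the retriever. `ToyPipeline`, `ToyReader`, `ToyRetriever`, `retriever_is_fitted`, and the JSON-based `dump_reader` are all hypothetical and do not reflect how cdQA serializes its models.

```python
import json


class ToyReader:
    """Hypothetical reader whose only state is a dict of weights."""

    def __init__(self, weights=None):
        self.weights = weights or {}


class ToyRetriever:
    """Hypothetical retriever; fitted state lives in `vocabulary_`."""

    def fit(self, documents):
        self.vocabulary_ = sorted({w for doc in documents for w in doc.lower().split()})
        return self


class ToyPipeline:
    def __init__(self, reader_json=None):
        weights = json.loads(reader_json) if reader_json else None
        self.reader = ToyReader(weights)
        self.retriever = ToyRetriever()  # a new pipeline always starts unfitted

    def fit_retriever(self, documents):
        self.retriever.fit(documents)
        return self

    def dump_reader(self):
        # Serializes the reader only; retriever state is left out.
        return json.dumps(self.reader.weights)

    def retriever_is_fitted(self):
        return hasattr(self.retriever, "vocabulary_")


docs = ["Chaplin was a comic actor"]

# "File A": fit the retriever, then save the reader.
pipeline_a = ToyPipeline().fit_retriever(docs)
pipeline_a.reader.weights = {"bias": 0.5}
saved = pipeline_a.dump_reader()

# "File B": rebuild from the saved reader; the retriever comes back unfitted.
pipeline_b = ToyPipeline(reader_json=saved)
fitted_after_load = pipeline_b.retriever_is_fitted()

# Fitting the retriever again restores a usable pipeline.
pipeline_b.fit_retriever(docs)
fitted_after_refit = pipeline_b.retriever_is_fitted()
```

The reader's weights survive the round trip, but the retriever's vocabulary does not, which is why the second script needs its own fit_retriever call before predict().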