kbressem / medAlpaca

LLM finetuned for medical question answering
GNU General Public License v3.0

Error running the model in Google Colab #13

Closed Mahgozar closed 1 year ago

Mahgozar commented 1 year ago

I get the following error when I try to run the sample code from the model card in Google Colab. The code:

from transformers import pipeline
qa_pipeline = pipeline("question-answering", model="medalpaca/medalpaca-13b", tokenizer="medalpaca/medalpaca-13b")
question = "What are the symptoms of diabetes?"
context = "Diabetes is a metabolic disease that causes high blood sugar. The symptoms include increased thirst, frequent urination, and unexplained weight loss."
answer = qa_pipeline({"question": question, "context": context})
print(answer)

The error:


KeyError                                  Traceback (most recent call last)

<ipython-input-6-936eb75f679c> in <cell line: 2>()
      1 from transformers import pipeline
----> 2 qa_pipeline = pipeline("question-answering", model="medalpaca/medalpaca-13b", tokenizer="medalpaca/medalpaca-13b")
      3 question = "What are the symptoms of diabetes?"
      4 context = "Diabetes is a metabolic disease that causes high blood sugar. The symptoms include increased thirst, frequent urination, and unexplained weight loss."
      5 answer = qa_pipeline({"question": question, "context": context})

2 frames

/usr/local/lib/python3.9/dist-packages/transformers/models/auto/configuration_auto.py in __getitem__(self, key)
    621             return self._extra_content[key]
    622         if key not in self._mapping:
--> 623             raise KeyError(key)
    624         value = self._mapping[key]
    625         module_name = model_type_to_module_name(key)

KeyError: 'llama'

kbressem commented 1 year ago

Not sure if the HF pipelines already fully support LLaMA models. Try using the medAlpaca inferer for now:

from medalpaca.inferer import Inferer

medalpaca = Inferer("medalpaca/medalpaca-7b", "prompt_templates/medalpaca.json")
response = medalpaca(input="What is Amoxicillin?")

This should let you prompt the model and also ensures that the correct prompt template is used.
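
The `KeyError: 'llama'` in the traceback is raised by the model-type lookup in `configuration_auto.py`, which suggests that the installed `transformers` release does not register the `llama` model type. Below is a minimal sketch for checking this in Colab before loading the model. It assumes LLaMA support first shipped in `transformers` 4.28.0. The helper names (`version_tuple`, `supports_llama`, `installed_transformers_version`) are hypothetical and not part of any library.

```python
import re
from importlib import metadata

# Assumption: the "llama" model type was first registered in transformers 4.28.0.
MIN_LLAMA_VERSION = (4, 28, 0)


def version_tuple(version: str) -> tuple:
    """Turn a version string such as '4.28.0.dev0' into (4, 28, 0)."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        # Non-numeric components count as 0.
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def supports_llama(version: str) -> bool:
    """Return True if this transformers version is at least the assumed minimum."""
    return version_tuple(version) >= MIN_LLAMA_VERSION


def installed_transformers_version():
    """Return the installed transformers version, or None if it is not installed."""
    try:
        return metadata.version("transformers")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    installed = installed_transformers_version()
    if installed is None:
        print("transformers is not installed")
    elif supports_llama(installed):
        print(f"transformers {installed} should recognize the 'llama' model type")
    else:
        print(f"transformers {installed} is probably too old for LLaMA models")
```

If the check reports an older version, upgrading with `pip install -U transformers` and restarting the Colab runtime may resolve the `KeyError`. The Inferer shown above remains the suggested route because it also applies the prompt template.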