This includes the original implementation of "SELF-RAG: Learning to Retrieve, Generate and Critique through Self-Reflection" by Akari Asai, Zeqiu Wu, Yizhong Wang, Avirup Sil, and Hannaneh Hajishirzi.
def format_prompt(input, paragraph=None):
    """Return a SELF-RAG instruction prompt for *input*.

    When *paragraph* is given, it is appended after the [Retrieval]
    control token so the model conditions on the retrieved evidence.
    """
    pieces = ["### Instruction:\n{0}\n\n### Response:\n".format(input)]
    if paragraph is not None:
        pieces.append("[Retrieval]{0}".format(paragraph))
    return "".join(pieces)
# Two demo queries: a simple instruction task and a factual comparison.
queries = [
    "Leave odd one out: twitter, instagram, whatsapp.",
    "Can you tell me the difference between llamas and alpacas?",
]
query_1, query_2 = queries
# For a query that doesn't require retrieval
# Build plain (no-paragraph) prompts and run batched generation.
prompts = [format_prompt(query) for query in queries]
preds = model.generate(prompts, sampling_params)
for pred in preds:
    # Each prediction holds candidate outputs; report the first one.
    print("Model prediction: {0}".format(pred.outputs[0].text))
Tried the quick-start code in the free tier of Google Colab and got an out-of-memory error. Is this expected?
from vllm import LLM, SamplingParams
# Load the SELF-RAG 7B checkpoint in half precision via vLLM, and use
# greedy decoding (temperature 0) while keeping the reflection special
# tokens visible in the decoded output.
model = LLM(
    "selfrag/selfrag_llama2_7b",
    download_dir="/gscratch/h2lab/akari/model_cache",
    dtype="half",
)
sampling_params = SamplingParams(
    temperature=0.0,
    top_p=1.0,
    max_tokens=100,
    skip_special_tokens=False,
)
def format_prompt(input, paragraph=None):
    """Format *input* as a SELF-RAG instruction prompt.

    If *paragraph* is provided, it is appended after the [Retrieval]
    control token (this variant adds a trailing space after it).
    """
    base = "### Instruction:\n{0}\n\n### Response:\n".format(input)
    if paragraph is None:
        return base
    return base + "[Retrieval]{0} ".format(paragraph)
# Demo inputs: one instruction-style query and one factual question.
query_1 = "Leave odd one out: twitter, instagram, whatsapp."
query_2 = "Can you tell me the difference between llamas and alpacas?"
queries = [query_1, query_2]
# For a query that doesn't require retrieval
# Run batched generation over the no-retrieval prompts, then print
# the first candidate output of every prediction.
preds = model.generate([format_prompt(query) for query in queries], sampling_params)
for pred in preds:
    print("Model prediction: {0}".format(pred.outputs[0].text))