dinobby / ReConcile

MIT License
168 stars 12 forks source link

Running ReConcile with a Llama 2 LLM #2

Open jabogithub opened 11 months ago

jabogithub commented 11 months ago

I would like to run ReConcile with a Llama 2 LLM, such as Llama 2 7B or 13B.

dinobby commented 11 months ago

Hi! Thanks for the comment. We recently tried Llama 2 70B within ReConcile and will update the code so that it is more flexible in choosing which models to use. For now, here are some scripts you may find useful:

import os
# Expose only GPUs 0-3 to this process. Both variables are set before
# transformers/torch are imported below — presumably so that CUDA sees only
# these devices when it initializes (TODO confirm this ordering is required).
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
# NOTE(review): WORLD_SIZE=1 presumably tells distributed-aware code to run as
# a single process — confirm it is actually needed here.
os.environ["WORLD_SIZE"] = "1"

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    BitsAndBytesConfig,
)
import torch

# Llama 2 70B chat checkpoint on the Hugging Face Hub (gated; requires access).
name = "meta-llama/Llama-2-70b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(name)
# Reuse EOS as the padding id for open-ended generation (the Llama 2 tokenizer
# presumably ships without a pad token — confirm).
tokenizer.pad_token_id = tokenizer.eos_token_id    # for open-ended generation

# 4-bit NF4 weight quantization with double (nested) quantization of the
# quantization constants; compute is done in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)
# device_map="auto" lets transformers/accelerate spread the layers over the
# visible GPUs.
# NOTE(review): trust_remote_code=True allows the Hub repo to execute arbitrary
# Python; Llama 2 is supported natively by transformers, so it is likely
# unnecessary here — confirm and consider dropping it.
model = AutoModelForCausalLM.from_pretrained(
    name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)
# Shared text-generation pipeline used by the snippets below.
# NOTE(review): the model is already instantiated, so trust_remote_code and
# device_map are likely ignored by the pipeline here — confirm.
generation_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    trust_remote_code=True,
    device_map="auto",
)

Prompt instructions for Llama 2:

def prepare_context_for_llama(test_sample, convincing_samples=None, additional_instruc=None, intervene=False, dataset="SQA"):
    """Build a Llama 2 chat prompt for a single question.

    The result follows the Llama 2 chat layout
    ``<s>[INST] <<SYS>>\\n{system}\\n<</SYS>>\\n\\n{user} [/INST]``. The system
    part asks for step-by-step reasoning and a JSON answer with
    ``reasoning`` / ``answer`` / ``confidence_level`` fields. The user part
    holds optional few-shot demonstrations, the question, and any extra
    instructions.

    Args:
        test_sample: dict with a ``'question'`` key; ``'gold_explanation'`` is
            also read when ``intervene`` is True.
        convincing_samples: optional iterable of dicts, each holding a
            ``'train_sample'`` dict with ``'question'``, ``'gold_explanation'``
            and ``'answer'``; rendered as Q/A demonstrations.
        additional_instruc: optional list of strings, joined with single
            spaces and appended directly after the question (callers include
            their own leading newlines).
        intervene: if True, tell the model to answer given the gold
            explanation as a fact.
        dataset: currently unused; kept for interface compatibility.

    Returns:
        The complete prompt string.
    """
    import json  # local import keeps this snippet self-contained

    context = []
    if convincing_samples:
        for cs in convincing_samples:
            train = cs['train_sample']
            # Render demonstrations as real JSON (double quotes) so they match
            # the JSON format the system prompt demands; str(dict) would emit
            # Python-repr single quotes for the model to imitate.
            demo = json.dumps(
                {"reasoning": train['gold_explanation'], "answer": train['answer']},
                ensure_ascii=False,
            )
            context.append(f"Q: {train['question']}\nA:{demo}")

    if intervene:
        context.append("Q: " + test_sample['question'] + "\nAnswer the question given the fact that " + test_sample['gold_explanation'])
    else:
        context.append("Q: " + test_sample['question'])

    context = "\n".join(context)
    if additional_instruc:
        context += " ".join(additional_instruc)

    # One instruction per line, without the source-code indentation that a
    # triple-quoted literal would leak into every line of the prompt.
    system_prompt = "\n".join([
        "You are a helpful assistant. Always answer as helpfully as possible.",
        "Please answer user's question with step-by-step reasoning.",
        "Also, evaluate your confidence level (between 0.0 and 1.0) to indicate the possibility of your answer being right.",
        'Output your answer in JSON format, with the format as follows: {"reasoning": "", "answer": "", "confidence_level": ""}. Please strictly output in JSON format. Put all your reasoning in the "reasoning" field.',
        'Only answer yes or no in the "answer" field.',
        "Do not output irrelevant content.",
        "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct.",
    ])
    # Llama 2 chat template: the system block is closed by "\n<</SYS>>\n\n"
    # before the user message starts.
    return f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{context} [/INST]"

Initial Response Generation:

# Initial round: one Llama 2 answer per test sample, using the Claude and GPT
# "convincing samples" as few-shot demonstrations.
llama_result = []
few_shot_samples = convincing_claude + convincing_gpt  # loop-invariant, build once
for test_sample in tqdm(test_samples):
    tmp = {}
    tmp['gold_answer'] = test_sample['answer']
    text = prepare_context_for_llama(test_sample, convincing_samples=few_shot_samples)

    sequences = generation_pipe(
        text,
        max_length=4096,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_k=10,
        temperature=0.7,
        top_p=0.9,
        # Return only the newly generated text. By default the pipeline
        # prepends the prompt, whose JSON template and few-shot answers may
        # otherwise be picked up by parse_json instead of the model's answer
        # (llama_gen_ans strips the prompt for the same reason).
        return_full_text=False,
    )
    tmp['prediction'] = parse_json(sequences[0]["generated_text"])
    llama_result.append(tmp)

Multi-Round Discussion:

@backoff.on_exception(backoff.expo, (ValueError, TypeError), max_tries=5)
def llama_gen_ans(sample, convincing_samples=None, additional_instruc=None, intervene=False, dataset="SQA"):
    """Ask Llama 2 about one sample and return its parsed, validated answer.

    Malformed generations are rejected by raising ValueError, which the
    backoff decorator retries with exponential backoff (up to 5 tries in
    total) before re-raising.

    Args:
        sample, convincing_samples, additional_instruc, intervene, dataset:
            forwarded unchanged to ``prepare_context_for_llama``.

    Returns:
        The dict produced by ``parse_json`` with its ``'answer'`` lower-cased.

    Raises:
        ValueError: output is not complete JSON, has no (or an empty)
            ``answer``, or the answer is longer than 3 characters.
    """
    context = prepare_context_for_llama(sample, convincing_samples, additional_instruc, intervene, dataset)

    sequences = generation_pipe(
        context,
        max_length=4096,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_k=10,
        temperature=0.7,
        top_p=0.9)

    # generated_text presumably starts with the prompt (pipeline default);
    # keep only what the model produced after it.
    result = sequences[0]["generated_text"].split(context)[-1]
    result = parse_json(result)
    if result == "ERR_SYNTAX":
        raise ValueError("incomplete JSON format.")
    # Use .get so a JSON object without an "answer" key is rejected with a
    # retryable ValueError rather than escaping as a KeyError, which the
    # backoff decorator does not catch.
    answer = result.get('answer') if isinstance(result, dict) else None
    # Answers longer than 3 characters cannot be "yes"/"no".
    if not answer or len(answer) > 3:
        raise ValueError("the generated answer is invalid.")

    result['answer'] = answer.lower()
    return result

def llama_debate(test_samples, all_results, rounds, convincing_samples, dataset="SQA"):
    """Run one discussion round with Llama 2 standing in for the Bard agent.

    For every entry of ``all_results`` that has a non-empty
    ``'debate_prompt_{rounds-1}'`` (the other agents' solutions) but no
    ``'bard_output_{rounds}'`` yet, Llama 2 answers the matching entry of
    ``test_samples`` while reviewing those solutions. The parsed answer is
    stored under ``'bard_output_{rounds}'``.

    Args:
        test_samples: samples aligned index-by-index with ``all_results``.
        all_results: list of per-sample result dicts; mutated in place.
        rounds: current round number (int).
        convincing_samples: few-shot demonstrations for the prompt.
        dataset: forwarded to ``llama_gen_ans``.

    Returns:
        ``all_results`` (the same, mutated object).
    """
    prev_prompt_key = f'debate_prompt_{rounds - 1}'
    output_key = f'bard_output_{rounds}'

    # Instructions placed before and after the other agents' solutions; they
    # are the same for every sample, so build them once.
    instruc_head = [
        "\n\nCarefully review the following solutions from other agents as additional information, and provide your own answer and step-by-step reasoning to the question.",
        "Clearly state which viewpoint you agree or disagree with and why.\n\n",
    ]
    instruc_tail = [
        "Output your answer in json format, with the format as follows: {\"reasoning\": \"\", \"answer\": \"\", \"confidence_level\": \"\"}. Please strictly output in JSON format.",
        "Place all your reasoning, including which parts you agree or disagree with and why, in the \"reasoning\" field.",
        "Do not output anything outside the json format.",
    ]

    # enumerate() has no length, so pass total= for a proper progress bar.
    for i, s in tqdm(enumerate(all_results), total=len(all_results)):
        # Skip samples already answered this round, or with nothing to review.
        if output_key in s or not s.get(prev_prompt_key):
            continue

        additional_instruc = [*instruc_head, s[prev_prompt_key], *instruc_tail]
        result = llama_gen_ans(test_samples[i],
                               convincing_samples=convincing_samples,
                               additional_instruc=additional_instruc,
                               intervene=False,
                               dataset=dataset)

        s[output_key] = result

    return all_results

We find that using Llama 2 70B in place of Bard also works well on StrategyQA. We haven't tried other model sizes or replacing the other agents, but we will refactor the code to make these easier to test! 🤓

For the second question, I don't think Llama 2 70B can fit in Google Colab, but you can try smaller ones. If you are using ChatGPT, Bard, and Claude 2, then it's perfectly fine to use Google Colab 🙌