ZHZisZZ / weak-to-strong-search

Weak-to-Strong Search: Align Large Language Models via Searching over Small Language Models
https://arxiv.org/abs/2405.19262

Weak-to-Strong Search

Code release for Weak-to-Strong Search: Align Large Language Models via Searching over Small Language Models.

Installation

conda create -n weak-to-strong-search python=3.10
conda activate weak-to-strong-search
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu118
pip install -r requirements.txt
# (optional) pip install flash-attn==2.3.2 --no-build-isolation
# (optional) pip install bitsandbytes==0.42.0
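To confirm that the environment matches the version pins in the installation commands, a small check like the following may help. It is a hypothetical helper, not part of the repository. It uses only the standard library's `importlib.metadata`, and it reports a missing package (e.g. the optional `flash-attn` or `bitsandbytes`) as `None` instead of raising.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the installation commands; flash-attn and bitsandbytes are
# optional, so a missing install is reported rather than treated as an error.
PINNED = {"torch": "2.1.0", "flash-attn": "2.3.2", "bitsandbytes": "0.42.0"}


def check_pins(pins=PINNED):
    """Return {package: (installed_version_or_None, pinned_version)}."""
    report = {}
    for name, pinned in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # package is not installed in this environment
        report[name] = (installed, pinned)
    return report


if __name__ == "__main__":
    for name, (installed, pinned) in check_pins().items():
        # CUDA wheels carry a local suffix such as "2.1.0+cu118", hence startswith.
        status = "ok" if installed and installed.startswith(pinned) else "check"
        print(f"{name}: installed={installed}, pinned={pinned} [{status}]")
```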

Quick Start

The example below uses `HuggingFaceH4/zephyr-7b-beta` and its untuned version `HuggingFaceH4/mistral-7b-sft-beta` to guide the decoding of `meta-llama/Meta-Llama-3-8B-Instruct` toward better alignment.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from src.inference_time_alignment.decoders.cbs import CBSPosthocGenerationMixin
from src.inference_time_alignment.scorers import ImplicitValueScorer


def get_zephyr_scorer() -> ImplicitValueScorer:
    """
    Use `zephyr-7b-beta` and its untuned version `mistral-7b-sft-beta` as a scorer to guide other models.
    """
    tuned_model = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceH4/zephyr-7b-beta", torch_dtype=torch.bfloat16, device_map="auto")
    untuned_model = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceH4/mistral-7b-sft-beta", torch_dtype=torch.bfloat16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    prompt_template = tokenizer.apply_chat_template(
        [
            {"role": "system", "content": ""},
            {"role": "user",   "content": "{raw_prompt}"},
        ],
        tokenize=False,
        add_generation_prompt=True,
    )
    implicit_value_scorer = ImplicitValueScorer(
        model=tuned_model,
        ref_model=untuned_model,
        tokenizer=tokenizer,
        model_prompt_template=prompt_template,
        ref_model_prompt_template=prompt_template,
    )
    return implicit_value_scorer


# the (stronger/larger) model to be guided
base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
prompt_template = tokenizer.apply_chat_template(
    [
        {"role": "system", "content": ""},
        {"role": "user",   "content": "{raw_prompt}"},
    ],
    tokenize=False,
    add_generation_prompt=True,
)

# chunk-level beam search wrapper
cbs_model = CBSPosthocGenerationMixin(base, tokenizer)
# implicit value scorer
scorer = get_zephyr_scorer()

# prepare prompts
raw_prompt = "Who are you?"
prompt = prompt_template.format(raw_prompt=raw_prompt)
prompt_tokenized = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
prompt_len = prompt_tokenized["input_ids"].size(1)

# search for the highest scoring response
outputs = cbs_model.search(
    input_ids=prompt_tokenized["input_ids"].cuda(),
    attention_mask=prompt_tokenized["attention_mask"].cuda(),
    scorer=scorer.set_raw_prompt(raw_prompt),
    split_by_prompt_text=False,
    w=2, k=2, l=30,  # CBS related args
    max_new_tokens=128,
)

print(tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True))
```
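Our reading of the paper is that the `w`, `k`, and `l` arguments to `cbs_model.search` are the beam width, the number of successors sampled per beam, and the chunk length in tokens. The pure-Python sketch below illustrates chunk-level beam search under that reading. `chunk_beam_search`, `sample_chunk`, and the toy scorer are hypothetical stand-ins, not the repository's `CBSPosthocGenerationMixin.search`, which works on token tensors, handles EOS, and calls the large model and the scorer directly.

```python
import random
from typing import Callable, List, Sequence

# sample_chunk(prefix, n, rng) -> n new tokens drawn from the (large) base model
Sampler = Callable[[Sequence[str], int, random.Random], List[str]]
# score(response_tokens) -> float, e.g. an implicit value from small models
Scorer = Callable[[Sequence[str]], float]


def chunk_beam_search(sample_chunk: Sampler, score: Scorer, w: int, k: int, l: int,
                      max_new_tokens: int, seed: int = 0) -> List[str]:
    """Keep up to w partial responses; extend each with k sampled chunks of up
    to l tokens; keep the w highest-scoring candidates; repeat until
    max_new_tokens tokens have been generated.

    Assumes l > 0 and that sample_chunk returns exactly n tokens (no EOS handling).
    """
    rng = random.Random(seed)
    beams: List[List[str]] = [[]]  # start from a single empty response
    generated = 0
    while generated < max_new_tokens:
        budget = min(l, max_new_tokens - generated)
        # expand every beam with k independently sampled chunks
        candidates = [beam + sample_chunk(beam, budget, rng)
                      for beam in beams for _ in range(k)]
        # rank by the guiding score and keep the top w
        candidates.sort(key=score, reverse=True)
        beams = candidates[:w]
        generated += budget
    return beams[0]


if __name__ == "__main__":
    # Toy stand-ins: a uniform random "large model" over two tokens, and a
    # scorer that simply prefers responses with more "a" tokens.
    def random_chunk(prefix, n, rng):
        return [rng.choice("ab") for _ in range(n)]

    def count_a(tokens):
        return tokens.count("a")

    print(chunk_beam_search(random_chunk, count_a, w=2, k=2, l=3, max_new_tokens=6))
```

With `w=1, k=1` the procedure reduces to plain sampling from the base model; larger `w` and `k` spend more samples per chunk in exchange for stronger steering by the scorer.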

See scripts/instruction_following for more examples.
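The quick start pairs a tuned small model with its untuned reference in `ImplicitValueScorer`. As we read the paper, the score guiding the search is the log-likelihood ratio between the two small models on the partial response, log π_tuned(y|x) − log π_untuned(y|x). The sketch below illustrates that quantity with toy next-token tables. `sequence_logprob`, `implicit_value`, `tuned`, and `untuned` are hypothetical names, and the repository's scorer (which works on token tensors and applies the prompt templates) may differ in details such as batching or normalization.

```python
import math
from typing import Callable, Dict, Sequence

# A hypothetical "model": given the prompt and the tokens generated so far,
# return a dict of next-token probabilities.
NextTokenProbs = Callable[[str, Sequence[str]], Dict[str, float]]


def sequence_logprob(model: NextTokenProbs, prompt: str, response: Sequence[str]) -> float:
    """Sum of log p(token | prompt, previous tokens) over the response."""
    total = 0.0
    for i, tok in enumerate(response):
        total += math.log(model(prompt, response[:i])[tok])
    return total


def implicit_value(tuned: NextTokenProbs, untuned: NextTokenProbs,
                   prompt: str, response: Sequence[str]) -> float:
    """log pi_tuned(y|x) - log pi_untuned(y|x) for a (partial) response y."""
    return (sequence_logprob(tuned, prompt, response)
            - sequence_logprob(untuned, prompt, response))


# Toy models: the "tuned" model puts more mass on "hello" than the "untuned" one.
def untuned(prompt, prefix):
    return {"hello": 0.5, "go": 0.5}


def tuned(prompt, prefix):
    return {"hello": 0.8, "go": 0.2}


if __name__ == "__main__":
    print(implicit_value(tuned, untuned, "Who are you?", ["hello"]))  # log(0.8 / 0.5) ~ 0.470
```

A positive score means the tuned model finds the response more likely than its untuned reference does, which is the signal used to rank candidate chunks from the larger model.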

Reference

@article{zhou2024weak,
  title={Weak-to-Strong Search: Align Large Language Models via Searching over Small Language Models},
  author={Zhou, Zhanhui and Liu, Zhixuan and Liu, Jie and Dong, Zhichen and Yang, Chao and Qiao, Yu},
  journal={arXiv preprint arXiv:2405.19262},
  year={2024}
}