jxiw / MambaInLlama

Official Repository of The Mamba in the Llama: Distilling and Accelerating Hybrid Models
https://arxiv.org/abs/2408.15237
Apache License 2.0

Could you share your code to generate "pseudo labels" #6

Closed tianshu-zhu closed 1 month ago

tianshu-zhu commented 1 month ago

Could you share your code to generate the input_ids.pt and labels.pt using the UltraChat and UltraFeedback datasets? As mentioned in your README:

Generate pseudo labels from a teacher model meta-llama/Meta-Llama-3-8B-Instruct. We provide the generated pseudo labels using the seed dataset of the UltraChat and UltraFeedback dataset here. Please download it and change the train_datasets_path in llama3_0.25_mamba.yaml and llama3_0.50_mamba.yaml to the path of your downloaded llama3_ultrafeedback and llama3_ultrachat.

jxiw commented 1 month ago

Hi, thank you for raising this question.

There are multiple ways to do this.

Stage 1: generate teacher model outputs using UltraChat (or UltraFeedback) as the seed prompts.

  1. Use HF Transformers, which might be slow.
import torch
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="cuda")
split = "train"
max_len = 2000
max_gen_length = 2048

# use ultrafeedback as the seed dataset
prefix = "ultrafeedback"
raw_dataset = load_dataset(f"HuggingFaceH4/{prefix}_binarized", split=f"{split}_gen")

# prefix = "ultrachat"
# raw_dataset = load_dataset(f"HuggingFaceH4/{prefix}_200k", split=f"{split}_gen")

messages = raw_dataset['messages']
formatted_prompts = [tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in messages]
prompts = [tokenizer.encode(formatted_prompt, return_tensors="pt", truncation=True, max_length=max_len) for formatted_prompt in formatted_prompts]
tokenizer.pad_token_id = tokenizer.eos_token_id
batch_size = 1

def pad_tensor(tensor, pad_token_id, target_length):
    # right-pad each generated batch to target_length so all batches can be concatenated below
    pad_size = target_length - tensor.shape[1]
    if pad_size <= 0:
        return tensor[:, :target_length]
    return torch.nn.functional.pad(tensor, (0, pad_size), value=pad_token_id)

gen_labels = []
generated_texts = []

temperature = 0.7
p = 0.95
k = 10

for i in tqdm(range(0, len(prompts), batch_size)):
    batch_prompts = torch.cat(prompts[i:i + batch_size], dim=0).to(model.device)
    outputs = model.generate(batch_prompts, max_length=max_gen_length, do_sample=True, temperature=temperature, top_k=k, top_p=p)
    outputs = pad_tensor(outputs, tokenizer.pad_token_id, max_gen_length)
    gen_labels.append(outputs)
    for output in outputs:
        generated_text = tokenizer.decode(output, skip_special_tokens=True)
        generated_texts.append(generated_text)

tensor_output_file = f"llama3_{temperature}_{k}_{p}_{split}_input_ids.pt"
padded_tensors = torch.cat(gen_labels, dim=0)
torch.save(padded_tensors, tensor_output_file)
  2. Use vLLM (see the sketch after this list for turning the generated outputs back into the input_ids tensor).
from vllm import LLM, SamplingParams

temperature = 0.7
p = 0.95
k = 10

# build formatted_prompts the same way as in the HF example above;
# vLLM takes the formatted text prompts directly (not tokenized tensors)
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
llm = LLM(model=model_name)
sampling_params = SamplingParams(temperature=temperature, top_p=p, top_k=k, max_tokens=max_gen_length)
outputs = llm.generate(formatted_prompts, sampling_params)
  3. Use Together AI. This is not free, so we recommend using vLLM instead.
import os
from datasets import load_dataset
from together import Together

split = "train"
prefix = "ultrafeedback"
raw_dataset = load_dataset(f"HuggingFaceH4/{prefix}_binarized", split=f"{split}_gen")

# prefix = "ultrachat"
# raw_dataset = load_dataset(f"HuggingFaceH4/{prefix}_200k", split=f"{split}_gen")
messages = raw_dataset['messages']

client = Together(api_key=os.environ.get("TOGETHER_API_KEY"))

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
generated_texts = []
for conversation in messages:
    # the API takes the messages of one conversation at a time
    stream = client.chat.completions.create(
        model=model_name,
        messages=conversation,
        stream=True,
    )
    generated_texts.append("".join(chunk.choices[0].delta.content or "" for chunk in stream))
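
If you go through vLLM, you still need to turn the generated outputs into the same padded input_ids tensor that the HF example saves. Here is a minimal sketch of that step, assuming the `outputs`, `tokenizer`, `max_gen_length`, `temperature`, `k`, and `p` variables from the examples above; it reads vLLM's per-request fields (prompt_token_ids and outputs[0].token_ids) to stitch the prompt and the completion back together.

import torch

# concatenate prompt and generated token ids, right-pad to max_gen_length with the
# eos/pad token, and save in the same input_ids.pt format as the HF example
rows = []
for output in outputs:
    token_ids = list(output.prompt_token_ids) + list(output.outputs[0].token_ids)
    token_ids = token_ids[:max_gen_length]
    token_ids += [tokenizer.eos_token_id] * (max_gen_length - len(token_ids))
    rows.append(torch.tensor(token_ids, dtype=torch.long))

padded_tensors = torch.stack(rows, dim=0)
torch.save(padded_tensors, f"llama3_{temperature}_{k}_{p}_train_input_ids.pt")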

Overall, the inference parameters that we use are

temperature = 0.7
p = 0.95
k = 10

As long as you do something similar, the results won't change much.

Stage 2: produce labels

Once you have the input_ids, we want to mask out the user and system tokens and only ask the model to predict the assistant tokens. So for the labels, the assistant tokens are kept as-is and all remaining positions are masked and set to -100. Here is a simple function for that (I admit it looks a bit ugly). If HF has a better API for this, please let me know.

def mask_pattern(tensor, start_pattern, end_pattern, IGNORE_INDEX):
    mask = torch.full(tensor.shape, IGNORE_INDEX)
    start_pattern_length = len(start_pattern)
    end_pattern_length = len(end_pattern)

    i = 0
    in_sequence = False
    start_index = 0

    # Single loop through the tensor
    while i < len(tensor):
        # Check if current segment matches the start pattern
        if not in_sequence and i <= len(tensor) - start_pattern_length and torch.equal(tensor[i:i+start_pattern_length], start_pattern):
            in_sequence = True
            start_index = i + start_pattern_length  # Begin marking after the start pattern
            i += start_pattern_length - 1  # Skip to end of pattern

        # Check if current segment matches the end pattern and we are in a sequence
        elif in_sequence and i <= len(tensor) - end_pattern_length and torch.equal(tensor[i:i+end_pattern_length], end_pattern):
            mask[start_index:i] = tensor[start_index:i]
            in_sequence = False  # Reset sequence flag
            i += end_pattern_length - 1  # Skip to end of pattern

        i += 1

    if in_sequence:
        mask[start_index:] = tensor[start_index:]

    return mask

And here is a simple use for Llama-3 models.

# Derive the pattern token ids from the Llama-3 tokenizer rather than hard-coding them:
# '<|eot_id|>' (id 128009 for Llama-3) ends each assistant turn, and
# 'assistant<|end_header_id|>' marks the start of each assistant turn.
eot_token_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
mask_start_pattern = torch.tensor(tokenizer.encode("assistant<|end_header_id|>", add_special_tokens=False))
mask_end_pattern = torch.tensor([eot_token_id])

# This extracts everything between 'assistant<|end_header_id|>' and '<|eot_id|>' (the model outputs)
# and sets the remaining positions to -100, so they are ignored by the cross entropy loss.
# training_data is the padded input_ids tensor produced in Stage 1.
training_label = [mask_pattern(tokenized_ids, mask_start_pattern, mask_end_pattern, IGNORE_INDEX=-100).unsqueeze(0) for tokenized_ids in tqdm(training_data)]
training_label = torch.concat(training_label, dim=0)
torch.save(training_label, 'labels.pt')
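
As a quick sanity check (a minimal sketch, assuming the `tokenizer`, `training_data`, and `training_label` variables from above), you can decode the unmasked positions of one example and confirm that only assistant text survives:

IGNORE_INDEX = -100

# decode only the positions that were kept in the labels; this should read as
# the assistant responses, with all user/system text masked out
example_ids = training_data[0]
example_labels = training_label[0]
kept_ids = example_ids[example_labels != IGNORE_INDEX]
print(tokenizer.decode(kept_ids, skip_special_tokens=True))
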
tianshu-zhu commented 1 month ago

Thank you so much!