kongds / MoRA

MoRA: High-Rank Updating for Parameter-Efficient Fine-Tuning
https://arxiv.org/abs/2405.12130
Apache License 2.0
331 stars · 18 forks

The experiments on Instruction Tuning and Continual Pretraining #13

Closed · lucasliunju closed this issue 3 months ago

lucasliunju commented 3 months ago

Hi, may I ask how to run the code for the Instruction Tuning and Continual Pretraining tasks?

Thank you very much in advance!

lucasliunju commented 3 months ago

In addition, may I ask how to evaluate the trained model? Could you please provide some examples?

kongds commented 3 months ago

Thanks for your interest in our work.

Because the data for instruction tuning and continual pretraining is large, we tokenize the data before training. To run instruction tuning, we use allenai/tulu-v2-sft-mixture with the following preprocessing code, based on https://github.com/allenai/open-instruct:

```python
from functools import partial
import torch
from transformers import AutoTokenizer
import datasets

def encode_with_messages_format(example, tokenizer, max_seq_length, add_bos=False):
    '''
    Here we assume each example has a 'messages' field Each message is a dict with 'role' and 'content' fields.
    We concatenate all messages with the roles as delimiters and tokenize them together.
    '''
    messages = example['messages']
    if len(messages) == 0:
        raise ValueError('messages field is empty.')

    def _concat_messages(messages):
        message_text = ""
        for message in messages:
            if message["role"] == "system":
                message_text += "<|system|>\n" + message["content"].strip() + "\n"
            elif message["role"] == "user":
                message_text += "<|user|>\n" + message["content"].strip() + "\n"
            elif message["role"] == "assistant":
                message_text += "<|assistant|>\n" + message["content"].strip() + tokenizer.eos_token + "\n"
            else:
                raise ValueError("Invalid role: {}".format(message["role"]))
        return message_text

    example_text = _concat_messages(messages).strip()
    tokenized_example = tokenizer(example_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
    input_ids = tokenized_example.input_ids
    labels = input_ids.clone()

    # mask the non-assistant part for avoiding loss
    for message_idx, message in enumerate(messages):
        if message["role"] != "assistant":
            if message_idx == 0:
                message_start_idx = 0
            else:
                message_start_idx = tokenizer(
                    _concat_messages(messages[:message_idx]), return_tensors='pt', max_length=max_seq_length, truncation=True
                ).input_ids.shape[1]
            if message_idx < len(messages) - 1 and messages[message_idx+1]["role"] == "assistant":
                # here we also ignore the role of the assistant
                messages_so_far = _concat_messages(messages[:message_idx+1]) + "<|assistant|>\n"
            else:
                messages_so_far = _concat_messages(messages[:message_idx+1])
            message_end_idx = tokenizer(
                messages_so_far,
                return_tensors='pt',
                max_length=max_seq_length,
                truncation=True
            ).input_ids.shape[1]
            labels[:, message_start_idx:message_end_idx] = -100

            if message_end_idx >= max_seq_length:
                break

    attention_mask = torch.ones_like(input_ids)
    return {
        'input_ids': input_ids.flatten().tolist(),
        'labels': labels.flatten().tolist(),
        'attention_mask': attention_mask.flatten().tolist(),
    }

tokenizer = AutoTokenizer.from_pretrained("daryl149/llama-2-7b-hf",use_fast=False)

dataset = datasets.load_dataset('allenai/tulu-v2-sft-mixture')
max_seq_length = 2048
num_added_tokens = tokenizer.add_special_tokens({
    "bos_token": "<s>",
    "eos_token": "</s>",
    "unk_token": "<unk>",
    "pad_token": "<pad>",
})
encode_function = partial(
    encode_with_messages_format,
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
    add_bos=add_bos,
)

lm_datasets = dataset['train'].map(
    encode_function,
    batched=False,
    num_proc=40,
    remove_columns=dataset["train"].column_names,
    desc="Tokenizing and reformatting instruction data",
)
lm_datasets = lm_datasets.filter(lambda example: (torch.LongTensor(example['labels']) != -100).any())
lm_datasets.save_to_disk('open-instruct-tokenized')
```

You can then run training with the following command:

```
RANK=8
deepspeed --num_gpus=8 --num_nodes=2 train.py \
    --base_model <LLAMA-2> --micro_batch_size 1 \
    --wandb_run_name mora_math_r8 --lora_target_modules q_proj,k_proj,v_proj,o_proj,gate_proj,down_proj,up_proj \
    --num_epochs 2 --deepspeed ds.config --wandb_project lora-instruction --lora_r $RANK --batch_size 128 \
    --data_path open-instruct-tokenized \
    --save_steps 3000 \
    --learning_rate 2e-4 --mora_type 6 \
    --logging_steps 5 --use_bf16 --use_16bit --use_mora \
    --lr_scheduler_type "cosine" --warmup_steps 150
```

For continual pretraining, we follow AdaptLLM and leverage an LLM to process domain data from PubMed abstracts and financial news, mixed with instruction data from OpenOrca.
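
For illustration only, here is a minimal sketch of packing plain domain text into the same tokenized format (input_ids/labels/attention_mask lists saved with save_to_disk) that the training script consumes; the OpenOrca instruction portion could be tokenized with the chat-format function above and concatenated. This is not the AdaptLLM processing itself, and "<DOMAIN_CORPUS>" is a placeholder dataset name:

```python
from transformers import AutoTokenizer
import datasets

tokenizer = AutoTokenizer.from_pretrained("daryl149/llama-2-7b-hf", use_fast=False)
max_seq_length = 2048

def encode_plain_text(example):
    # For continual pretraining the whole sequence contributes to the loss,
    # so labels are simply a copy of input_ids.
    ids = tokenizer(example["text"], max_length=max_seq_length, truncation=True).input_ids
    return {
        "input_ids": ids,
        "labels": list(ids),
        "attention_mask": [1] * len(ids),
    }

# "<DOMAIN_CORPUS>" stands in for the processed PubMed / financial news data.
domain = datasets.load_dataset("<DOMAIN_CORPUS>", split="train")
tokenized = domain.map(encode_plain_text, remove_columns=domain.column_names, num_proc=40)
tokenized.save_to_disk("continual-pretraining-tokenized")
```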

kongds commented 3 months ago

For evaluation, we first merge the MoRA weights into the LLM.
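
For reference, a minimal merging sketch, assuming the repo's modified peft keeps the standard PeftModel.from_pretrained / merge_and_unload API; the checkpoint path below is a placeholder:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained("daryl149/llama-2-7b-hf", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("daryl149/llama-2-7b-hf", use_fast=False)

model = PeftModel.from_pretrained(base_model, "<MORA_CHECKPOINT_DIR>")  # placeholder adapter path
model = model.merge_and_unload()  # fold the MoRA/LoRA update into the base weights

model.save_pretrained("llama-2-7b-mora-merged")
tokenizer.save_pretrained("llama-2-7b-mora-merged")
```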

To evaluate on math, we directly use eval_gsm8k.py and eval_math.py from https://github.com/meta-math/MetaMath?tab=readme-ov-file#evaluation.

For instruction evaluation, we use a similar script from https://github.com/allenai/open-instruct/blob/main/scripts/eval/mmlu.sh to get MMLU results.

To evaluate on continual pretraining, we collect corresponding tasks and obtain results through in-context learning.

lucasliunju commented 3 months ago

Hi @kongds

Thanks for your detailed response.

I find that the training prompt is:

"Below is an instruction that describes a task, paired with an input that provides further context. " "Write a response that appropriately completes the request.\n\n" "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"

But the eval prompt in eval_gsm8k.py is:

"Below is an instruction that describes a task. " "Write a response that appropriately completes the request.\n\n" "### Instruction:\n{instruction}\n\n### Response: Let's think step by step."

I'm not sure whether that difference will affect the results.

Thank you very much in advance!

kongds commented 3 months ago

This training prompt is also from MetaMath; you can find it at https://github.com/meta-math/MetaMath/blob/fe667b14f9a4c51bde9809adafb82b6b5d255800/train_math.py#L48-L52

lucasliunju commented 3 months ago

Thanks for your help!

lucasliunju commented 3 months ago

Thanks again for your great repo! Your code is pretty good; we can also reproduce the performance of other baselines.

I have reproduced the results on MetaMath, and now I want to reproduce the results on the open-instruct tasks. I would like to ask something about the hyper-parameters. In the appendix of your paper, the warmup steps are 500, but in your provided example the warmup steps are 150. May I ask whether this difference in warmup steps matters?

Thank you very much for your help in advance.

kongds commented 3 months ago

Sorry for the confusion.

I checked the experiments and the paper, and it seems that the warmup steps mentioned in the paper are incorrect (I will fix it in the next version). The warmup steps in our experiments are 150. The 150 warmup steps come from open-instruct, which uses a 0.03 warmup ratio; since the total training steps are about 5000, we use 150 warmup steps.
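
As a quick check of that arithmetic:

```python
# Warmup steps as described above: open-instruct's 0.03 warmup ratio
# applied to roughly 5000 total training steps.
warmup_ratio = 0.03
total_training_steps = 5000
warmup_steps = int(warmup_ratio * total_training_steps)
print(warmup_steps)  # 150
```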

Regarding the influence of different warmup steps, I haven't explored it. It seems that fewer warmup steps lead to faster convergence than more warmup steps, but I don't know how they influence the final results.

lucasliunju commented 3 months ago

Thanks for your quick reply, I will try it.

lucasliunju commented 3 months ago

Thanks for your interest in our work.

Because the data for instruction tuning and continual pretraining is large, we tokenize the data before training. To run instruction tuning, we use allenai/tulu-v2-sft-mixture with the following preprocessing code, based on https://github.com/allenai/open-instruct:


```python
from functools import partial
import torch
from transformers import AutoTokenizer
import datasets

def encode_with_messages_format(example, tokenizer, max_seq_length, add_bos=False):
    '''
    Here we assume each example has a 'messages' field Each message is a dict with 'role' and 'content' fields.
    We concatenate all messages with the roles as delimiters and tokenize them together.
    '''
    messages = example['messages']
    if len(messages) == 0:
        raise ValueError('messages field is empty.')

    def _concat_messages(messages):
        message_text = ""
        for message in messages:
            if message["role"] == "system":
                message_text += "<|system|>\n" + message["content"].strip() + "\n"
            elif message["role"] == "user":
                message_text += "<|user|>\n" + message["content"].strip() + "\n"
            elif message["role"] == "assistant":
                message_text += "<|assistant|>\n" + message["content"].strip() + tokenizer.eos_token + "\n"
            else:
                raise ValueError("Invalid role: {}".format(message["role"]))
        return message_text

    example_text = _concat_messages(messages).strip()
    tokenized_example = tokenizer(example_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
    input_ids = tokenized_example.input_ids
    labels = input_ids.clone()

    # mask the non-assistant part for avoiding loss
    for message_idx, message in enumerate(messages):
        if message["role"] != "assistant":
            if message_idx == 0:
                message_start_idx = 0
            else:
                message_start_idx = tokenizer(
                    _concat_messages(messages[:message_idx]), return_tensors='pt', max_length=max_seq_length, truncation=True
                ).input_ids.shape[1]
            if message_idx < len(messages) - 1 and messages[message_idx+1]["role"] == "assistant":
                # here we also ignore the role of the assistant
                messages_so_far = _concat_messages(messages[:message_idx+1]) + "<|assistant|>\n"
            else:
                messages_so_far = _concat_messages(messages[:message_idx+1])
            message_end_idx = tokenizer(
                messages_so_far,
                return_tensors='pt',
                max_length=max_seq_length,
                truncation=True
            ).input_ids.shape[1]
            labels[:, message_start_idx:message_end_idx] = -100

            if message_end_idx >= max_seq_length:
                break

    attention_mask = torch.ones_like(input_ids)
    return {
        'input_ids': input_ids.flatten().tolist(),
        'labels': labels.flatten().tolist(),
        'attention_mask': attention_mask.flatten().tolist(),
    }

tokenizer = AutoTokenizer.from_pretrained("daryl149/llama-2-7b-hf",use_fast=False)

dataset = datasets.load_dataset('allenai/tulu-v2-sft-mixture')
max_seq_length = 2048
num_added_tokens = tokenizer.add_special_tokens({
    "bos_token": "<s>",
    "eos_token": "</s>",
    "unk_token": "<unk>",
    "pad_token": "<pad>",
})
encode_function = partial(
    encode_with_messages_format,
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
    add_bos=add_bos,
)

lm_datasets = dataset['train'].map(
    encode_function,
    batched=False,
    num_proc=40,
    remove_columns=dataset["train"].column_names,
    desc="Tokenizing and reformatting instruction data",
)
lm_datasets = lm_datasets.filter(lambda example: (torch.LongTensor(example['labels']) != -100).any())
lm_datasets.save_to_disk('open-instruct-tokenized')
```

You can then run training with the following command:

```
RANK=8
deepspeed --num_gpus=8 --num_nodes=2 train.py \
    --base_model <LLAMA-2> --micro_batch_size 1 \
    --wandb_run_name mora_math_r8 --lora_target_modules q_proj,k_proj,v_proj,o_proj,gate_proj,down_proj,up_proj \
    --num_epochs 2 --deepspeed ds.config --wandb_project lora-instruction --lora_r $RANK --batch_size 128 \
    --data_path open-instruct-tokenized \
    --save_steps 3000 \
    --learning_rate 2e-4 --mora_type 6 \
    --logging_steps 5 --use_bf16 --use_16bit --use_mora \
    --lr_scheduler_type "cosine" --warmup_steps 150
```

For continual pretraining, we follow AdaptLLM and leverage an LLM to process domain data from PubMed abstracts and financial news, mixed with instruction data from OpenOrca.

Hi, I find that add_bos is still not defined (the partial call uses add_bos=add_bos), and the default add_bos=False does not seem to work, so I guess it should be True. May I ask whether add_bos should be True?

kongds commented 3 months ago

Hi, I forgot to remove some code related to add_bos. (Setting add_bos to False should work.) Here is the revised version, which should run without errors.

```python
from functools import partial
import torch
from transformers import AutoTokenizer
import datasets

def encode_with_messages_format(example, tokenizer, max_seq_length, add_bos=False):
    '''
    Here we assume each example has a 'messages' field Each message is a dict with 'role' and 'content' fields.
    We concatenate all messages with the roles as delimiters and tokenize them together.
    '''
    messages = example['messages']
    if len(messages) == 0:
        raise ValueError('messages field is empty.')

    def _concat_messages(messages):
        message_text = ""
        for message in messages:
            if message["role"] == "system":
                message_text += "<|system|>\n" + message["content"].strip() + "\n"
            elif message["role"] == "user":
                message_text += "<|user|>\n" + message["content"].strip() + "\n"
            elif message["role"] == "assistant":
                message_text += "<|assistant|>\n" + message["content"].strip() + tokenizer.eos_token + "\n"
            else:
                raise ValueError("Invalid role: {}".format(message["role"]))
        return message_text

    example_text = _concat_messages(messages).strip()
    tokenized_example = tokenizer(example_text, return_tensors='pt', max_length=max_seq_length, truncation=True)
    input_ids = tokenized_example.input_ids
    labels = input_ids.clone()

    # mask the non-assistant part for avoiding loss
    for message_idx, message in enumerate(messages):
        if message["role"] != "assistant":
            if message_idx == 0:
                message_start_idx = 0
            else:
                message_start_idx = tokenizer(
                    _concat_messages(messages[:message_idx]), return_tensors='pt', max_length=max_seq_length, truncation=True
                ).input_ids.shape[1]
            if message_idx < len(messages) - 1 and messages[message_idx+1]["role"] == "assistant":
                # here we also ignore the role of the assistant
                messages_so_far = _concat_messages(messages[:message_idx+1]) + "<|assistant|>\n"
            else:
                messages_so_far = _concat_messages(messages[:message_idx+1])
            message_end_idx = tokenizer(
                messages_so_far,
                return_tensors='pt',
                max_length=max_seq_length,
                truncation=True
            ).input_ids.shape[1]
            labels[:, message_start_idx:message_end_idx] = -100

            if message_end_idx >= max_seq_length:
                break

    attention_mask = torch.ones_like(input_ids)
    return {
        'input_ids': input_ids.flatten().tolist(),
        'labels': labels.flatten().tolist(),
        'attention_mask': attention_mask.flatten().tolist(),
    }

tokenizer = AutoTokenizer.from_pretrained("daryl149/llama-2-7b-hf",use_fast=False)

dataset = datasets.load_dataset('allenai/tulu-v2-sft-mixture')
max_seq_length = 2048
num_added_tokens = tokenizer.add_special_tokens({
    "bos_token": "<s>",
    "eos_token": "</s>",
    "unk_token": "<unk>",
    "pad_token": "<pad>",
})
encode_function = partial(
    encode_with_messages_format,
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
)

lm_datasets = dataset['train'].map(
    encode_function,
    batched=False,
    num_proc=40,
    remove_columns=dataset["train"].column_names,
    desc="Tokenizing and reformatting instruction data",
)
lm_datasets = lm_datasets.filter(lambda example: (torch.LongTensor(example['labels']) != -100).any())
lm_datasets.save_to_disk('open-instruct-tokenized')
```

lucasliunju commented 3 months ago

Thanks for your reply. I will try it.

lucasliunju commented 3 months ago

Hi, I still encounter an error at step 904:

```
| 904/5010 [5:33:04<25:04:33, 21.99s/it]../aten/src/ATen/native/cuda/Indexing.cu:1289: indexSelectLargeIndex: block: [835,0,0], thread: [0,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1289: indexSelectLargeIndex: block: [835,0,0], thread: [1,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1289: indexSelectLargeIndex: block: [835,0,0], thread: [2,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1289: indexSelectLargeIndex: block: [835,0,0], thread: [3,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1289: indexSelectLargeIndex: block: [835,0,0], thread: [4,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1289: indexSelectLargeIndex: block: [835,0,0], thread: [5,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1289: indexSelectLargeIndex: block: [835,0,0], thread: [6,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1289: indexSelectLargeIndex: block: [835,0,0], thread: [7,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
...
with exception: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```

I'm not sure whether this is because the CUDA version does not match.

lucasliunju commented 3 months ago

Hi, I've switched to another cluster, and it has successfully run over 1000 steps, which looks promising. I'll check whether any errors similar to the one before appear. Interestingly, although the runs use the same random_seed, I've noticed that the loss curves are different. I suspect that the set_seed function may not guarantee full reproducibility of results.

kongds commented 3 months ago

For this error, we set out-of-vocabulary token indices to 0 if the data_path contains open-instruct-tokenized, as in the lines below. If you renamed the path, this error may occur because those out-of-range token indices are not replaced. https://github.com/kongds/MoRA/blob/0ff64b144e60b54fe7c0ff7b4e76c99c949e923d/train.py#L624-L639
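
For illustration only, a sketch of this kind of guard (not the repo's exact code; see the linked train.py lines for the actual implementation), assuming token ids at or above the LLaMA-2 vocabulary size of 32000 are mapped back to 0:

```python
import torch

VOCAB_SIZE = 32000  # LLaMA-2 vocabulary size before adding a new pad token

def clamp_out_of_vocab(example):
    # Replace any token id outside the embedding range with 0 so that
    # embedding lookups cannot trigger the device-side indexing assert.
    ids = torch.LongTensor(example["input_ids"])
    ids[ids >= VOCAB_SIZE] = 0
    example["input_ids"] = ids.tolist()
    return example

# e.g. lm_datasets = lm_datasets.map(clamp_out_of_vocab)
```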

Regarding reproducibility, set_seed works in my testing. The difference may come from the environment or a different micro_batch_size.

lucasliunju commented 3 months ago

Hi, thanks for your reply.

I just checked my data path, and I use the same path, open-instruct-tokenized.

kongds commented 3 months ago

I checked the code and experiments again. I forgot to add --new_pad_token to the command; the following command should work. Sorry for this mistake.

```
RANK=8
deepspeed --num_gpus=8 --num_nodes=2 train.py \
    --base_model <LLAMA-2> --micro_batch_size 1 \
    --wandb_run_name mora_math_r8 --lora_target_modules q_proj,k_proj,v_proj,o_proj,gate_proj,down_proj,up_proj \
    --num_epochs 2 --deepspeed ds.config --wandb_project lora-instruction --lora_r $RANK --batch_size 128 \
    --data_path open-instruct-tokenized \
    --save_steps 3000 \
    --learning_rate 2e-4 --mora_type 6 \
    --logging_steps 5 --use_bf16 --use_16bit --use_mora \
    --lr_scheduler_type "cosine" --warmup_steps 150 --new_pad_token
```

kongds commented 3 months ago

This error is caused by the tokenizer using index 32000 for the pad token. However, there is text containing "<pad>" in open-instruct, which produces tokens with index 32000, outside the model's original embedding range of 0 to 31999.

Using --new_pad_token to resize the token embedding from 32000 to 32001 can solve this problem.
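
For illustration, a sketch of what that resize amounts to with the standard transformers API (not the exact train.py code):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("daryl149/llama-2-7b-hf", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("daryl149/llama-2-7b-hf")

# Register a dedicated <pad> token and grow the embedding matrix so that
# token id 32000 has a valid row: 32000 -> 32001 entries.
tokenizer.add_special_tokens({"pad_token": "<pad>"})
model.resize_token_embeddings(len(tokenizer))
```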

lucasliunju commented 3 months ago

Hi,

Thanks for your reply. So I just need to add --new_pad_token and don't need to change the code in train.py?

kongds commented 3 months ago

Yes

lucasliunju commented 3 months ago

Thanks, I will try it.

lucasliunju commented 3 months ago

Hi, I find that the result is quite different when I use a single GPU versus 2 GPUs to run the code. In addition, the loss values are very similar when we use two GPUs with different seeds (that is not normal, because we expect different results with different seeds). But if we use a single GPU, the loss values with different seeds are different (normal, the results differ). I think the main reason is that set_seed does not take effect when we use two GPUs.

lucasliunju commented 3 months ago

Sorry I have so many questions.

kongds commented 3 months ago

It's okay. Thanks for reproducing our work.

Regarding the seed issue, we use the seed not only in set_seed but also in transformers.TrainingArguments. For transformers.TrainingArguments, the seed is set again within the transformers library to ensure reproducibility. https://github.com/kongds/MoRA/blob/0ff64b144e60b54fe7c0ff7b4e76c99c949e923d/train.py#L673
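
For illustration, a sketch of how the two seeding paths fit together with the standard transformers API (the actual values live in train.py):

```python
from transformers import TrainingArguments, set_seed

seed = 42  # example value only

set_seed(seed)  # seeds Python, NumPy and PyTorch RNGs up front
training_args = TrainingArguments(
    output_dir="output",
    seed=seed,  # the Trainer calls set_seed(args.seed) again internally
)
```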

I think the difference may be caused by different gradient_accumulation_steps. The single-GPU run uses a large gradient_accumulation_steps, which may make the training differ.

lucasliunju commented 3 months ago

Hi, thanks for your reply.

I find that the loss curves of multi-GPU training are very similar when I use different seeds. That makes me wonder whether we really set the seed for multi-GPU training. But when I use a single GPU to train the model with different seeds, the loss values are different. That means the random seed works for a single GPU but not for multi-GPU training. I guess this may come from randomness introduced by DeepSpeed.

kongds commented 3 months ago

We use the Trainer from transformers, and the seed is also set by transformers. I think it is compatible with DeepSpeed.

lucasliunju commented 3 months ago

Hi, thanks for your reply.

I'm trying to reproduce the results and will let you know when I finish.

kongds commented 3 months ago

Thank you. But I still think the seed is correctly set even with multiple GPUs (we train on 32 GPUs, and the seed works for us), and that the reason for this difference is not the seed.

Regarding the difference in loss, I think it is acceptable for the loss to be similar even if we use different seeds. This is because we use the same type of data, and the losses are averaged with a sufficiently large batch size, such as 128.

lucasliunju commented 3 months ago

Hi, thanks for your reply.

Did you try running the code with different seeds when using 32 GPUs? I tried changing the seed, and the loss curves are very similar. That means we get the same loss values with different seeds.

These are my loss curves with different seeds when using 2 GPUs, with one sample per device:

(screenshot: loss curves for different seeds)

lucasliunju commented 3 months ago

But when I use a single GPU to train, we get different loss values with different seeds.

kongds commented 3 months ago
(screenshot: training loss curves for different seeds)

Here is my training loss with different seeds; each run uses 16 GPUs.

lucasliunju commented 3 months ago

So sad. Let me check my environment.

Thanks for your help.

lucasliunju commented 3 months ago

Hi, I have finished training and find that the test accuracies of LoRA and MoRA on MMLU-0 are about 49.2 and 49.4, respectively. I'm not sure whether these results are correct, since the reported result for LoRA in the paper is 50.2.

kongds commented 3 months ago

Hello, it seems the settings in your experiment are the same as ours (we also use a 2e-4 learning rate to run LoRA in our paper). However, the gradient_accumulation_steps in your experiment differs from ours, and I am not sure whether this may cause the different LoRA results. For example, we use a micro batch size of 1 with 16 GPUs, so gradient_accumulation_steps is set to 8; but it seems you use 2 GPUs, where gradient_accumulation_steps would be set to 64.
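
For reference, the effective batch size is the same in both setups; only the accumulation factor differs:

```python
# effective batch size = micro_batch_size * num_gpus * gradient_accumulation_steps
micro_batch_size = 1

assert micro_batch_size * 16 * 8 == 128   # paper setup: 16 GPUs, accumulation 8
assert micro_batch_size * 2 * 64 == 128   # 2-GPU setup: accumulation 64
```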

lucasliunju commented 3 months ago

Thanks for your detailed reply.

Yes, let me try to use more GPUs.

lucasliunju commented 3 months ago

Hi, may I ask what the best learning rates are for MoRA and LoRA?

kongds commented 3 months ago

In our experiments, 2e-4 seems to work best for rank 8 with both MoRA and LoRA on instruction tuning.

lucasliunju commented 1 week ago

Hi, may I ask a question that is not related to MoRA? I am trying to use this codebase to fine-tune LLaMA-3-8B, but I find I cannot use vLLM to run inference and evaluation if I directly use the inference code from the MetaMath repo. It would be great if you could give me some advice. Thank you very much in advance.

kongds commented 1 week ago

Hello, did you merge the MoRA or LoRA weights before inference?

lucasliunju commented 1 week ago

Yes, I have merged the LoRA or MoRA weights into the base model and then run inference and evaluation with vLLM.

lucasliunju commented 1 week ago

But LLaMA-2 works well.

kongds commented 1 week ago

It seems to be a problem with vLLM. Maybe you need to upgrade vLLM to support LLaMA-3.

lucasliunju commented 4 days ago

Thanks for your reply! I'll give it a try. By the way, I find that the training performance is quite different if I change the versions of transformers and peft.