kongds / MoRA

MoRA: High-Rank Updating for Parameter-Efficient Fine-Tuning
https://arxiv.org/abs/2405.12130
Apache License 2.0
341 stars 20 forks source link

可以提供复现UUID的数据集和实验配置吗 #18

Closed 2018211801 closed 2 months ago

2018211801 commented 2 months ago

你好呀,很优秀的工作。我想follow你的工作,可不可以提供一下你的UUID数据集,prompt以及超参等信息呢?非常感谢~~

kongds commented 2 months ago

你好,感谢关注我们的工作 我们使用如下的方式构建数据集

import uuid
import datasets
from transformers import AutoTokenizer
keys = []
values = []

for i in range(10000):
    keys.append(str(uuid.uuid4()))
    values.append(str(uuid.uuid4()))

dataset = datasets.Dataset.from_dict({"key": keys, "value": values})
def tokenize(entry):
    result = tokenizer(
        entry['key'] + ':' + entry['value'],
    )
    result['labels'] = result['input_ids']
    return result
dataset = dataset.map(tokenize, num_proc=40)
dataset.save_to_disk("data/uuid10k-tokenized")

训练的话可以参考如下脚本 lora r256

deepspeed --num_gpus=8 --num_nodes=1 train.py \
        --micro_batch_size 128 --wandb_run_name lora-r256 \
        --num_epochs 100 --wandb_project lora-memory --batch_size 1024 \
        --data_path data/uuid10k-tokenized --logging_steps 1 \
        --lora_target_modules q_proj,k_proj,v_proj,o_proj,gate_proj,down_proj,up_proj \
        --lora_r 256 --lora_alpha 128 --warmup_steps 100  \
        --learning_rate 3e-4  --grad_checkpoint

mora r256

deepspeed --num_gpus=8 --num_nodes=1 train.py \
        --micro_batch_size 128 --wandb_run_name mora-r256 \
        --num_epochs 100 --wandb_project lora-memory --batch_size 1024 \
        --data_path data/uuid10k-tokenized --logging_steps 1 \
        --lora_target_modules q_proj,k_proj,v_proj,o_proj,gate_proj,down_proj,up_proj \
        --lora_r 256 --warmup_steps 100  \
        --learning_rate 5e-5 --grad_checkpoint --use_mora
2018211801 commented 2 months ago

好滴,非常非常感谢呀~~