yangjianxin1 / Firefly

Firefly: a large-model training tool that supports training Qwen2, Yi1.5, Phi-3, Llama3, Gemma, MiniCPM, Yi, Deepseek, Orion, Xverse, Mixtral-8x7B, Zephyr, Mistral, Baichuan2, Llama2, Llama, Qwen, Baichuan, ChatGLM2, InternLM, Ziya2, Vicuna, Bloom, and other large models

Error when training llama3-8b-it #256

Open wx971025 opened 1 month ago

wx971025 commented 1 month ago

I installed the required packages exactly as described in the README:

pip install -r requirements.txt
pip install git+https://github.com/unslothai/unsloth.git
pip install bitsandbytes==0.43.1
pip install peft==0.10.0
pip install torch==2.2.2
pip install xformers==0.0.25.post1

Launch command:

export CUDA_VISIBLE_DEVICES=2,3,4,5,6,7
torchrun --nproc_per_node=6 train.py \
               --train_args_file train_args/sft/qlora/llama3-8b-sft-qlora.json

I am using 6×A800 GPUs and get the following error:

ValueError: You can't train a model that has been loaded in 8-bit precision on a different device than the one you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device()}` or `device_map={'':torch.xpu.current_device()}`

What causes this error?
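The error message quoted above suggests loading the model onto the GPU that belongs to the current process. A minimal sketch of that idea follows, assuming the script is launched with torchrun (which exports `LOCAL_RANK` to each worker). The helper name `per_process_device_map` is hypothetical and is not part of Firefly, unsloth, or accelerate; whether this resolves the error when unsloth is enabled is not confirmed in this thread.

```python
import os


def per_process_device_map(env=None):
    """Build a device_map that places the whole model on this process's GPU.

    torchrun exports LOCAL_RANK (0 .. nproc_per_node - 1) to every worker.
    The empty-string key means "the entire model".
    """
    env = os.environ if env is None else env
    # Fall back to device 0 when not launched through torchrun.
    local_rank = int(env.get("LOCAL_RANK", 0))
    return {"": local_rank}


# Hypothetical usage inside a training script (an assumption, not Firefly's code):
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name_or_path, device_map=per_process_device_map(), ...)
print(per_process_device_map({"LOCAL_RANK": "3"}))  # {'': 3}
```

With `CUDA_VISIBLE_DEVICES=2,3,4,5,6,7`, local rank 3 is CUDA device index 3 inside the process, which corresponds to physical GPU 5.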

Parasolation commented 1 month ago

Could you share your train_args?

wx971025 commented 1 month ago

> Could you share your train_args?

{
    "output_dir": "output/firefly-llama3-8b-sft-qlora",
    "model_name_or_path": "/data1/models/llms/llama3_8b_it",
    "train_file": "./data/llama3/dummy_data.jsonl",
    "template_name": "llama3",
    "num_train_epochs": 2,
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 1,
    "learning_rate": 2e-4,
    "max_seq_length": 1024,
    "logging_steps": 100,
    "save_steps": 100,
    "save_total_limit": 1,
    "lr_scheduler_type": "constant_with_warmup",
    "warmup_steps": 100,
    "lora_rank": 64,
    "lora_alpha": 16,
    "lora_dropout": 0.05,
    "use_unsloth": true,

    "gradient_checkpointing": true,
    "disable_tqdm": false,
    "optim": "paged_adamw_32bit",
    "seed": 42,
    "fp16": true,
    "report_to": "tensorboard",
    "dataloader_num_workers": 10,
    "save_strategy": "steps",
    "weight_decay": 0,
    "max_grad_norm": 0.3,
    "remove_unused_columns": false
}

This looks like a problem in accelerate itself; when I don't use unsloth, the error does not occur.
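The comment above does not say how unsloth was turned off. One plausible way, given the `use_unsloth` key in the posted arguments, is to flip that flag in the args file. The sketch below works under that assumption; the helper name `without_unsloth` is hypothetical, and only a subset of the posted arguments is reproduced.

```python
import json


def without_unsloth(train_args: dict) -> dict:
    """Return a copy of the training arguments with unsloth disabled."""
    updated = dict(train_args)  # shallow copy, the input is left untouched
    updated["use_unsloth"] = False
    return updated


# Subset of the arguments posted in this thread.
args = {
    "template_name": "llama3",
    "lora_rank": 64,
    "lora_alpha": 16,
    "use_unsloth": True,
}

print(json.dumps(without_unsloth(args), indent=4))
```

The printed JSON could then be saved as a separate args file and passed to `--train_args_file`, so the original configuration stays available for comparison.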