foundation-model-stack / fms-acceleration

🚀 Collection of libraries used with fms-hf-tuning to accelerate fine-tuning and training of large models.
Apache License 2.0

Fix FSDP when performing GPTQ-LoRA with Triton V2 #15

Closed · fabianlim closed this 1 month ago

fabianlim commented 1 month ago

GPTQ-LoRA depends on the AutoGPTQ package, but there are issues that prevent the base GPTQ model from being wrapped with FSDP.

The issue comes from the fact that the QuantLinear class stores its parameters (i.e., qweight, qzeros) in torch.int32, which results in FSDP failing when it builds its flat parameters, since the int32 tensors cannot be flattened together with the model's floating-point weights.
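For context, a rough sketch of the failure mode (the module and shapes are stand-ins; the exact FSDP error text depends on the PyTorch version):

```python
import torch
import torch.nn as nn

class FakeQuantLinear(nn.Module):
    """Stand-in for AutoGPTQ's QuantLinear: packed weights live in int32."""
    def __init__(self):
        super().__init__()
        self.qweight = nn.Parameter(torch.zeros(256, 64, dtype=torch.int32),
                                    requires_grad=False)
        self.scales = nn.Parameter(torch.zeros(64, dtype=torch.float16),
                                   requires_grad=False)

# Wrapping such a module with torch.distributed.fsdp.FullyShardedDataParallel
# fails while building the flat parameter, because int32 and float16
# parameters cannot be flattened into one uniform-dtype buffer.
```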

The fix is then to use torch.Tensor.view, which does a C++-style reinterpret cast, in QuantLinear.forward before calling into the QuantLinearFunction autograd function. We create the nn.Parameter in the same vein, by doing qweight.view(torch_type) to force the parameter to be of torch_type (which is going to be a float type).
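For illustration, a minimal sketch of the view trick (the shapes and the float16 choice are stand-ins, not the actual patch; torch.Tensor.view(dtype) is the documented no-copy byte reinterpretation):

```python
import torch

# Stand-in for QuantLinear.qweight: GPTQ packs weights into int32 words
qweight_int32 = torch.randint(-2**31, 2**31 - 1, (1024, 256), dtype=torch.int32)

# Register it as a float parameter so FSDP can flatten and shard it;
# .view(dtype) reinterprets the same bytes, no copy is made
qweight_param = torch.nn.Parameter(qweight_int32.view(torch.float16),
                                   requires_grad=False)

# Inside forward, view back to int32 before the quantized matmul
restored = qweight_param.data.view(torch.int32)
assert torch.equal(restored, qweight_int32)  # bit-identical round trip
```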

Reproduce

To reproduce the fix, run the following command:

export CUDA_VISIBLE_DEVICES=0,1
accelerate launch --config_file scripts/benchmarks/accelerate.yaml --num_processes=2 --main_process_port=29501 \
-m tuning.sft_trainer --model_name_or_path TheBloke/Nous-Hermes-Llama2-70B-GPTQ --acceleration_framework_config_file /data/flim/fms-acceleration-oss/scripts/benchmarks/../../sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml --packing True --max_seq_len 4096 --learning_rate 2e-4 --fp16 True --torch_dtype float16 --peft_method lora --r 16 --lora_alpha 16 --lora_dropout 0.0 --target_modules q_proj k_proj v_proj o_proj --use_flash_attn True --response_template '
### Response:' --dataset_text_field output --include_tokens_per_second True --num_train_epochs 1 --gradient_accumulation_steps 1 --gradient_checkpointing True --evaluation_strategy no --save_strategy no --weight_decay 0.01 --warmup_steps 10 --adam_epsilon 1e-4 --lr_scheduler_type linear --logging_strategy steps --logging_steps 10 --max_steps 100 --training_data_path benchmark_outputs/data/cache.json --per_device_train_batch_size 2 --output_dir benchmark_outputs/exp_1/hf
Measured per stage, with (Y) and without (N) the fix:

| fix | prepare | forward | backward |
| --- | ------- | ------- | -------- |
| N   | 35 B    | 48 B    | 36 B     |
| Y   | 18 B    | 32 B    | 19 B     |

Losses and throughputs after the fix:

{'loss': 1.0727, 'grad_norm': 0.09820556640625, 'learning_rate': 0.0002, 'epoch': 0.01}
{'loss': 0.9477, 'grad_norm': 0.19384765625, 'learning_rate': 0.00017777777777777779, 'epoch': 0.03}
{'loss': 0.9168, 'grad_norm': 0.07080078125, 'learning_rate': 0.00015555555555555556, 'epoch': 0.04}
{'loss': 0.9182, 'grad_norm': 0.0616455078125, 'learning_rate': 0.00013333333333333334, 'epoch': 0.06}
{'loss': 0.8815, 'grad_norm': 0.056671142578125, 'learning_rate': 0.00011111111111111112, 'epoch': 0.07}
{'loss': 0.9014, 'grad_norm': 0.058685302734375, 'learning_rate': 8.888888888888889e-05, 'epoch': 0.08}
{'loss': 0.8754, 'grad_norm': 0.06451416015625, 'learning_rate': 6.666666666666667e-05, 'epoch': 0.1}
{'loss': 0.8863, 'grad_norm': 0.05322265625, 'learning_rate': 4.4444444444444447e-05, 'epoch': 0.11}
{'loss': 0.8929, 'grad_norm': 0.0596923828125, 'learning_rate': 2.2222222222222223e-05, 'epoch': 0.13}
{'loss': 0.8821, 'grad_norm': 0.057037353515625, 'learning_rate': 0.0, 'epoch': 0.14}
{'train_runtime': 1882.4921, 'train_samples_per_second': 0.212, 'train_steps_per_second': 0.053, 'train_tokens_per_second': 435.168, 'train_loss': 0.9174925422668457, 'epoch': 0.14}

TODO:

fabianlim commented 1 month ago

So, because of the casting, we are now facing this error in #25:


  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl

    return forward_call(*args, **kwargs)

  File "/workspace/fms-acceleration/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py", line 75, in _forward_q

    _fused_op(X)

  File "/workspace/fms-acceleration/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py", line 64, in _fused_op

    Q, K, V = fused_operation(attn, X)

  File "/workspace/fms-acceleration/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py", line 620, in apply_lora_qkv

    Q, K, V = LoRA_QKV.apply(

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply

    return super().apply(*args, **kwargs)  # type: ignore[misc]

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 115, in decorate_fwd

    return fwd(*args, **kwargs)

  File "/workspace/fms-acceleration/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py", line 464, in forward

    QW = dequant248(Q_qweight, Q_scales, Q_qzeros, Q_g_idx, Q_bits)

  File "/workspace/fms-acceleration/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/kernels.py", line 137, in dequant248

    dequant_kernel_248[grid](

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 143, in run

    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>

    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 122, in _bench

    return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/testing.py", line 102, in do_bench

    fn()

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 110, in kernel_call

    self.fn.run(

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/runtime/jit.py", line 532, in run

    self.cache[device][key] = compile(

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/compiler/compiler.py", line 543, in compile

    next_module = compile_kernel(module)

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/compiler/compiler.py", line 435, in <lambda>

    ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))

  File "/workspace/fms-acceleration/.tox/run-benches/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1237, in ast_to_ttir

    raise CompilationError(fn.src, node, repr(e)) from e
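Presumably the fused op now receives the qweight storage in its float-viewed form, so the Triton dequant kernel sees the wrong element type. One way to make such a consumer robust is to reinterpret the tensor back before dequantizing; a minimal sketch (the as_int32 helper is hypothetical, not part of the repo):

```python
import torch

def as_int32(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical guard: undo the FSDP-friendly float view before a
    GPTQ dequant kernel that expects packed int32 words."""
    # torch.Tensor.view(dtype) reinterprets the bytes without copying
    return t if t.dtype == torch.int32 else t.view(torch.int32)

# e.g., apply as_int32 to Q_qweight / Q_qzeros before calling dequant248(...)
```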