Maybe this can be addressed by distributing the adapters using DDP, just as it was done with the AutoGPTQ version
Turns out this error is thrown because the base layer weight `W` is on the `cpu` device when it is passed into `W = fast_dequantize(W.t(), W_quant)` in BNB's `fast_lora.py`, whereas dequantization of `W` needs to happen on `cuda`. The adapters `A` and `B` were also on `cpu`, and they subsequently throw a device mismatch error when matmul'd with `X` (which is on `cuda`) inside the `matmul_lora` function.
fused_ops/unsloth_lora/utils.py
```python
def matmul_lora(X, W, W_quant, A, B, s, out = None):
    dtype = X.dtype
    W = fast_dequantize(W.t(), W_quant)

    if X.dim() == 3:
        batch, seq_len, d = X.shape
        X = X.view(-1, X.shape[-1])
        reshape = True
    else:
        reshape = False
    pass

    out = torch.matmul(X, W, out = out)
    if W_quant is not None: del W

    if A is not None:
        # LoRA is enabled
        A, B = A.t(), B.t()
        out += (X @ A.to(dtype)) @ (s * B.to(dtype))
    pass

    return out.view(batch, seq_len, -1) if reshape else out
pass
```
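As a minimal illustration of the adapter failure mode (not the actual code path), the LoRA matmul raises a device-mismatch error as soon as `X` lives on `cuda` while the adapters are still on `cpu`; the shapes below are arbitrary:

```python
import torch

# Minimal sketch of the device mismatch (illustrative shapes, not real weights):
# X is on cuda, but the LoRA adapters were left on cpu by low memory mode.
X = torch.randn(32, 64, dtype=torch.float16, device="cuda")
A = torch.randn(8, 64, dtype=torch.float16)    # lora_A weight still on cpu
B = torch.randn(128, 8, dtype=torch.float16)   # lora_B weight still on cpu

out = (X @ A.t()) @ B.t()  # raises a RuntimeError: tensors on different devices
```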
Just before the FOAK patching, the model itself has been cast to `cuda`, but the `self_attn` base layer weights and adapters are still on `cpu` due to low memory mode:
```
(Pdb) model.device
device(type='cuda', index=0)
(Pdb) model.get_base_model().model.layers[0].self_attn.q_proj.base_layer.weight.device
device(type='cpu')
(Pdb) model.get_base_model().model.layers[0].self_attn.q_proj.lora_A.default.weight.device
device(type='cpu')
```
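For reference, a quick way to dump the device of every parameter in one attention block is something like the sketch below; it assumes `model` is the loaded PEFT model inspected in the pdb session above, and the module path follows that same layout:

```python
# Hypothetical helper: list the device of every parameter in the first
# self-attention block of the PEFT-wrapped model.
attn = model.get_base_model().model.layers[0].self_attn
print([(name, param.device) for name, param in attn.named_parameters()])
```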
However, removing the FOAK patching seems to reverse the problem: FSDP-QLoRA with low memory mode trains perfectly fine.
My guess is that since the FOAK patch happens before the trainer prepares the model, the patching is performed on weights still residing on `cpu`, which subsequently causes problems when `self` references module weights that were never placed on the GPU.
As a temporary workaround, I cast the attention module to `X.device` when `X` is passed in, as shown below. This is not the correct solution, but it avoids the error.
fused_ops/unsloth_lora/bnb/fast_lora.py
```python
def apply_lora_qkv(self, X):
    self = self.to(X.device) # TEMPFIX: adding this will cast the module to device
    QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
    KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
    VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
    Q, K, V = LoRA_QKV.apply(X,
        QW, QW_quant, QA, QB, QS,
        KW, KW_quant, KA, KB, KS,
        VW, VW_quant, VA, VB, VS,
    )
    return Q, K, V
pass
```
```
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --config_file scripts/benchmarks/accelerate.yaml --num_processes=2 --main_process_port=29500 -m tuning.sft_trainer --model_name_or_path mistralai/Mistral-7B-v0.1 --acceleration_framework_config_file sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml --packing True --max_seq_len 4096 --fp16 True --learning_rate 2e-4 --torch_dtype float16 --peft_method lora --r 16 --lora_alpha 16 --lora_dropout 0.0 --target_modules q_proj k_proj v_proj o_proj --use_flash_attn True --response_template '\n### Response:' --dataset_text_field 'output' --include_tokens_per_second True --num_train_epochs 1 --gradient_accumulation_steps 1 --gradient_checkpointing True --evaluation_strategy no --save_strategy no --weight_decay 0.01 --warmup_steps 10 --adam_epsilon 1e-4 --lr_scheduler_type linear --logging_strategy steps --logging_steps 10 --max_steps 30 --training_data_path benchmark_outputs/data/cache.json --skip_memory_metrics True --per_device_train_batch_size 4 --output_dir benchmark_outputs/exp_1/hf
```
@achew010 are you sure that this fixes the BNB case? I realized I was getting the exact same error with the GPTQ case. The reason is #26, where in low_mem mode we no longer move the whole model directly to GPU, and we also exclude the adapters from FSDP, which is why the adapters stayed on CPU.
So I fixed it in #29.
@achew010 Update: the root cause is not the lora weights staying on `cpu`; you can verify this with the following. After my fix in #29, the lora adapter weights are on the GPU while the base layer weights remain on `cpu`:
```
[('q_proj.base_layer.weight', device(type='cpu')),
 ('q_proj.lora_A.default.weight', device(type='cuda', index=0)),
 ('q_proj.lora_B.default.weight', device(type='cuda', index=0)),
 ('k_proj.base_layer.weight', device(type='cpu')),
 ('k_proj.lora_A.default.weight', device(type='cuda', index=0)),
 ('k_proj.lora_B.default.weight', device(type='cuda', index=0)),
 ('v_proj.base_layer.weight', device(type='cpu')),
 ('v_proj.lora_A.default.weight', device(type='cuda', index=0)),
 ('v_proj.lora_B.default.weight', device(type='cuda', index=0)),
 ('o_proj.base_layer.weight', device(type='cpu')),
 ('o_proj.lora_A.default.weight', device(type='cuda', index=0)),
 ('o_proj.lora_B.default.weight', device(type='cuda', index=0))]
```
I think the real issue is that BNB QLoRA does not work with FSDP low memory mode, and we need to fix it at the root cause. The workaround feels dangerous because FSDP shards and unshards the parameters, so putting a `.to` in a forward function is not very safe.
@fabianlim you are right, QLoRA doesn't work with FSDP and low memory mode; the weights stay on `cpu` until the FSDP wrapping here. Similar to the issue with #29, the base layer wasn't cast because of the FSDP `ignored_modules` in the FSDP-FOAK workaround. I made a fix here; it runs before `Trainer.train` is called and will not interfere with FSDP sharding and unsharding at forward time. I'm encountering the same device errors with `set_module_tensor_to_device`, so I'm using the `.cuda` method for now.
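Very roughly, the idea is something like the sketch below (a hypothetical loop, not the actual patch): walk the attention modules that FSDP ignores and cast them with `.cuda()` before `Trainer.train` is called.

```python
import torch

# Rough sketch, not the actual fix: cast the FSDP-ignored attention modules to
# the local GPU before Trainer.train() so the fused kernels see cuda weights.
device_index = torch.cuda.current_device()
for layer in model.get_base_model().model.layers:
    # .cuda() is used here because set_module_tensor_to_device raised the same
    # device errors, as noted above.
    layer.self_attn.cuda(device_index)
```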
For GPTQ, the fix in #29 resolved the casting of adapters in ignored modules to `cuda`. The base layer is mapped to a `meta` device from model initialization (see in code) and the weights don't materialize until training.
I was thinking of loading the QLoRA weights on the `meta` device (reference) and maybe adopting the same way GPTQ reinitializes its `meta` tensors, but I haven't found out how exactly GPTQ does the re-initialization in its code yet.
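For illustration, a bare-bones version of that meta-device idea might look like the sketch below, where a plain `nn.Linear` stands in for the quantized base layer; how GPTQ actually re-materializes its `meta` tensors is, as noted, still an open question.

```python
import torch
import torch.nn as nn

# Sketch only: allocate the layer on the meta device (no real storage) ...
with torch.device("meta"):
    base_layer = nn.Linear(4096, 4096, bias=False)

# ... then materialize empty storage on the GPU later; the real weights would
# still have to be loaded/quantized into it before training.
base_layer = base_layer.to_empty(device="cuda")
```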
This has been addressed by #31.
Problem
Distributed experiments in the benchmarks fail when using BNB's nf4 QLoRA with Unsloth fused module optimizations.
Cause
Distributed experiments for BNB's nf4 QLoRA (without the fused module optimizations) don't throw any errors, so the suspected cause is an incompatibility between FSDP, the BNB kernels, and Unsloth's matmul.
Stacktrace from test repo:
Setting the debug environment variable `CUDA_LAUNCH_BLOCKING=1` produces this: `Error an illegal memory access was encountered at line 90 in file /src/csrc/ops.cu`. This is traced to the `dequantizeBlockwise` CUDA function.
Reproduce