File "/venv/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 52, in _flash_attn_varlen_forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type
This error is primarily because flash attention does not support head masks. So we actually did not use flash attention during sparsification; we only use it later during distillation.
https://github.com/GeneZC/MiniMA/commit/01891a5827e3b14d9180127a6f824db37011e9fd
We have updated the codebase to not use flash attention for sparsification, so you can give it a try now.
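For intuition, here is a minimal sketch of the dispatch (assuming PyTorch 2.x and a per-head mask of shape (num_heads,); this is not the code in modeling_sparsellama.py): a fused kernel gives no hook for scaling each head's attention probabilities, so when a head mask is supplied the code has to materialize the attention matrix explicitly.

import math
import torch

def attention_with_optional_head_mask(q, k, v, head_mask=None):
    # q, k, v: (batch, num_heads, seq_len, head_dim); head_mask: (num_heads,) or None
    if head_mask is None:
        # No per-head masking needed, so a fused kernel (flash attention or
        # scaled_dot_product_attention) can be used.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    # Head mask requested: materialize the attention weights so the mask can
    # scale each head's probabilities before they are applied to the values.
    attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(q.size(-1))
    causal = torch.full(attn.shape[-2:], float("-inf"), device=q.device, dtype=attn.dtype).triu(1)
    attn = torch.softmax(attn + causal, dim=-1)
    attn = attn * head_mask.to(attn.dtype).view(1, -1, 1, 1)  # zero/scale pruned heads
    return torch.matmul(attn, v)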
Thanks for the update. The sparsification now starts. However, I'm getting an out of memory error running on one A100 GPU with 80GB VRAM:
***** Running sparsification (w. sanity check) *****
Traceback (most recent call last):
File "/root/MiniMA/minima/run_sparsification_llama.py", line 230, in <module>
main()
File "/root/MiniMA/minima/run_sparsification_llama.py", line 142, in main
output = model(**batch, head_mask=head_mask, neuron_mask=neuron_mask, hidden_mask=hidden_mask)
File "/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/root/MiniMA/minima/modules/modeling_sparsellama.py", line 873, in forward
outputs = self.model(
File "/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/root/MiniMA/minima/modules/modeling_sparsellama.py", line 689, in forward
layer_outputs = decoder_layer(
File "/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/root/MiniMA/minima/modules/modeling_sparsellama.py", line 464, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/root/MiniMA/minima/modules/modeling_sparsellama.py", line 378, in forward
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacty of 79.15 GiB of which 1.79 GiB is free. Process 2805463 has 77.34 GiB memory in use. Of the allocated memory 76.71 GiB is allocated by PyTorch, and 139.70 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
How much memory does this need? In the tutorial, you mentioned it's possible to prune with 1xA100.
That is strange. I use exactly one A100-80G for sparsification. Are you sparsifying a larger llama, say llama-13b?
No, just llama2 7b
Then could you please provide other details, e.g., hyperparameters?
What hyperparameters are you referring to? The commands I ran are copy-pasted from the Tutorial page in this repo.
My dataset (the one used during pruning) is 1.5GB.
I'm trying to sparsify an already fine-tuned llama2 model; I don't believe that should make a difference in terms of memory usage?
For example, are you using a larger sequence length? I used a relatively short sequence length during pruning, i.e., 512.
These are the two commands I ran:
python run_building_data_llama.py \
--input_dir "/root/Layla-datasets/datasets_formatted/layla" \
--input_regex "*.txt" \
--output_dir /root/output/layla_tfrecords \
--tokenizer_name_or_path /root/llama2-7b-layla \
--do_lower_case \
--max_seq_length 4096 \
--num_processors 32
python run_sparsification_llama.py \
--model_type sparsellama_lm \
--teacher_model_name_or_path /root/llama2-7b-layla \
--record_path_or_regex "/root/output/layla_tfrecords/*.tfrecord" \
--data_type llama_lm \
--output_dir output/distilled_layla \
--max_length 512 \
--per_device_eval_batch_size 2 \
--model_suffix 7b
I can try with a smaller pruning length? Will that affect quality?
Oh, I am awfully sorry that I did not give the tip that the data for pruning should be rebuilt with --max_seq_length 512 instead, so that it can be used for pruning.
Besides, we have not tested much whether the pruning length affects the quality. Maybe we could examine that later.
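As a rough back-of-the-envelope (illustrative numbers only, assuming llama2-7b's 32 attention heads, fp16 weights, and the batch size of 2 from your command), the 4096-token records account for the exact 2 GiB allocation that fails in the attention matmul in the traceback above:

batch, num_heads, seq_len, bytes_per_elem = 2, 32, 4096, 2  # fp16
attn_bytes = batch * num_heads * seq_len * seq_len * bytes_per_elem
print(attn_bytes / 2**30)                        # 2.0 GiB per attn_weights tensor
print(attn_bytes / (4096 // 512) ** 2 / 2**20)   # 32.0 MiB at seq_len 512

Rebuilding the records at 512 makes that tensor 64x smaller.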
No problem, thanks for helping me. I will try again with a shorter sequence length. Perhaps you should update your tutorial page with the correct seq_length as well.
Yes, I shall update it later.
If you further encounter any questions, please let me know.
The sparsification is running now, thank you for the help!
I'm wondering how long it takes on an A100? It has been running the whole night (about 10 hours) now. Wondering if it's stuck or if it's supposed to take that long.
It should take a long time; in my case, 1GB of data would take more than 1 day to go : <
Ran sparsification for 1.5 days, got this error:
***** Running sparsification (w. sanity check) *****
Traceback (most recent call last):
File "/root/MiniMA/minima/run_sparsification_llama.py", line 230, in <module>
main()
File "/root/MiniMA/minima/run_sparsification_llama.py", line 145, in main
assert torch.isnan(loss) == False, "Loss is NaN!"
AssertionError: Loss is NaN!
Basically, I wrote the NaN loss detection in case of any loss spiking, which may potentially result in unexpected behavior during pruning. However, I have not encountered this issue during pruning in my experiments. The issue in your case is perhaps correlated with your data, so I suggest adding an if-nan-then-continue check to skip the offending data. Or directly use float32 instead of float16.
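For example, a minimal sketch of the if-nan-then-continue idea (a hypothetical helper; adapt it to the actual loop in run_sparsification_llama.py):

import torch

def loss_is_finite(loss: torch.Tensor) -> bool:
    # True only if the loss contains no NaN/Inf values.
    return bool(torch.isfinite(loss).all())

# Inside the sparsification loop, instead of
#     assert torch.isnan(loss) == False, "Loss is NaN!"
# one could skip the offending batch:
#     if not loss_is_finite(loss):
#         continue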
BTW, I am trying to integrate deepspeed and flash attention into the pruning process so that you can achieve higher speed : /
Hi @l3utterfly
I have updated the pruning process with deepspeed and flash attention, which greatly reduces the compute time from more than a day to several hours.
Hope you will find it useful!
https://github.com/GeneZC/MiniMA/commit/eb91dedd042f2d7e8323c6ca8377d8789b8f563e
Thank you so much! I will start a new run tonight!
I tried to distill with the new code. I noticed the pruning process now recommends 8 A100 GPUs?
I am using 1 A100 GPU and it still runs out of memory with my 1GB dataset. A lower number of GPUs should only affect the time taken, not memory usage, right?
With fewer GPUs, maybe the batch size should also be decreased accordingly, since deepspeed permits a larger per-GPU batch size when more GPUs are used.
I reduced the per_device_eval_batch_size to 1.
Also, what should the values here be? --nnodes=$NODE_WORLD_SIZE --node_rank=$NODE_RANK
In your case, the two parameters should be removed. And I am not quite sure whether deepspeed works with 1 GPU or not.
Should I remove the deepspeed argument then?
Not really; deepspeed is integrated for the speedup, and removing it will result in errors... Perhaps you could try 2 GPUs or so.
I don't have access to 2 GPUs sadly. Are you using the 40GB A100 or the 80GB version?
80GB version.
I am closing this issue since it is not active, feel free to reopen it as you like.
Trying with the llama2 base weights, I get the following error:
After hardcoding use_cache=False and continuing, I get the following error:
Can you help please?
Also:
from modules.fused_rope_monkey_patch_llama import apply_rotary_pos_emb
This seems to be a wrong import? Should it be:
from modules.modeling_llama import apply_rotary_pos_emb
in flash_attn_monkey_patch_llama.py, line 10?