GeneZC / MiniMA

Code for paper titled "Towards the Law of Capacity Gap in Distilling Language Models"
Apache License 2.0

Getting errors when trying to replicate the distilling operation #2

Closed: l3utterfly closed this issue 6 months ago

l3utterfly commented 7 months ago

Trying with the llama2 base weights.

I get the following error:

File "/root/MiniMA/minima/modules/flash_attn_monkey_patch_sparsellama.py", line 47, in forward
    assert not use_cache, "use_cache is not supported"

After hardcoding use_cache=False and continuing, I get the following error:

File "/venv/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 52, in _flash_attn_varlen_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type

Can you help please?

Also, in flash_attn_monkey_patch_llama.py (line 10), the import from modules.fused_rope_monkey_patch_llama import apply_rotary_pos_emb seems wrong.

Should it be from modules.modeling_llama import apply_rotary_pos_emb instead?

GeneZC commented 7 months ago

File "/venv/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 52, in _flash_attn_varlen_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type

This error arises primarily because flash attention does not support the head mask, so we actually do not use flash attention during sparsification and only use it during distillation.

https://github.com/GeneZC/MiniMA/commit/01891a5827e3b14d9180127a6f824db37011e9fd

We have updated the codebase so that sparsification no longer uses flash attention; you could give it a try now.
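
If you are on an older checkout, the workaround is roughly the following (a sketch only; the flag and the patch function name here are illustrative, not the repo's exact API):

# Sketch: apply the flash-attn monkey patch only for distillation.
# Sparsification feeds a head_mask, which the flash-attn kernels cannot consume,
# and a model kept in fp32 would also trip the "only support fp16 and bf16" check.
USE_FLASH_ATTN = False  # keep False for sparsification, True for distillation

if USE_FLASH_ATTN:
    # hypothetical helper name mirroring the repo's monkey-patch modules
    from modules.flash_attn_monkey_patch_llama import replace_llama_attn_with_flash_attn
    replace_llama_attn_with_flash_attn()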

l3utterfly commented 7 months ago

Thanks for the update. The sparsification now starts. However, I'm getting an out-of-memory error running on one A100 GPU with 80GB VRAM:

***** Running sparsification (w. sanity check) *****
Traceback (most recent call last):
  File "/root/MiniMA/minima/run_sparsification_llama.py", line 230, in <module>
    main()
  File "/root/MiniMA/minima/run_sparsification_llama.py", line 142, in main
    output = model(**batch, head_mask=head_mask, neuron_mask=neuron_mask, hidden_mask=hidden_mask)
  File "/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/MiniMA/minima/modules/modeling_sparsellama.py", line 873, in forward
    outputs = self.model(
  File "/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/MiniMA/minima/modules/modeling_sparsellama.py", line 689, in forward
    layer_outputs = decoder_layer(
  File "/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/MiniMA/minima/modules/modeling_sparsellama.py", line 464, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/MiniMA/minima/modules/modeling_sparsellama.py", line 378, in forward
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacty of 79.15 GiB of which 1.79 GiB is free. Process 2805463 has 77.34 GiB memory in use. Of the allocated memory 76.71 GiB is allocated by PyTorch, and 139.70 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

How much memory does this need? In the tutorial, you mentioned it's possible to prune with 1xA100.

GeneZC commented 7 months ago

That is strange. I use exactly one A100-80G for sparsification. Are you sparsifying a larger llama, say llama-13b?

l3utterfly commented 7 months ago

No, just llama2 7b

GeneZC commented 7 months ago

Then could you please provide other details, e.g., hyperparameters?

l3utterfly commented 7 months ago

What hyperparameters are you referring to? The commands I ran are copy-pasted from the tutorial page in this repo.

My dataset is 1.5GB (dataset used during pruning).

I'm trying to sparsify an already fine-tuned llama2 model; I don't believe that should make a difference in terms of memory usage?

GeneZC commented 7 months ago

For example, are you using a larger sequence length? I used a relatively short sequence length during pruning, i.e., 512.

l3utterfly commented 7 months ago

These are the two commands I ran:

python run_building_data_llama.py \
    --input_dir "/root/Layla-datasets/datasets_formatted/layla" \
    --input_regex "*.txt" \
    --output_dir /root/output/layla_tfrecords \
    --tokenizer_name_or_path /root/llama2-7b-layla \
    --do_lower_case \
    --max_seq_length 4096 \
    --num_processors 32

python run_sparsification_llama.py \
    --model_type sparsellama_lm \
    --teacher_model_name_or_path /root/llama2-7b-layla \
    --record_path_or_regex "/root/output/layla_tfrecords/*.tfrecord" \
    --data_type llama_lm \
    --output_dir output/distilled_layla \
    --max_length 512 \
    --per_device_eval_batch_size 2 \
    --model_suffix 7b

l3utterfly commented 7 months ago

I can try with a smaller pruning length? Will that affect quality?

GeneZC commented 7 months ago

Oh, I am awfully sorry that I did not give the tip that the data for pruning should be rebuilt with --max_seq_length 512 instead so that it can be used for pruning.
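
Concretely, the data-building command from your message would become something like this (same paths; pointing --output_dir at a fresh directory, e.g. with a _512 suffix as below, is just my suggestion so the 4096-token records do not get mixed in):

python run_building_data_llama.py \
    --input_dir "/root/Layla-datasets/datasets_formatted/layla" \
    --input_regex "*.txt" \
    --output_dir /root/output/layla_tfrecords_512 \
    --tokenizer_name_or_path /root/llama2-7b-layla \
    --do_lower_case \
    --max_seq_length 512 \
    --num_processors 32

Then point --record_path_or_regex in the sparsification command at the new *.tfrecord files.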

Besides, we have not tested much whether the pruning length affects the quality. Maybe we will examine that later.

l3utterfly commented 7 months ago

No problem, thanks for helping me. I will try again with a shorter sequence length. Perhaps you should update your tutorial page with the correct seq_length as well.

GeneZC commented 7 months ago

Yes, I shall update it later.

If you encounter any further questions, please let me know.

l3utterfly commented 7 months ago

The sparsification is running now, thank you for the help!

I'm wondering how long it takes on an A100? It has been running for the whole night (about 10 hours) now, and I'm not sure whether it's stuck or it's supposed to take that long.

GeneZC commented 7 months ago

It should take a long time; in my case, 1GB of data takes more than 1 day to go : <

l3utterfly commented 7 months ago

I ran sparsification for 1.5 days and got this error:

***** Running sparsification (w. sanity check) *****
Traceback (most recent call last):
  File "/root/MiniMA/minima/run_sparsification_llama.py", line 230, in <module>
    main()
  File "/root/MiniMA/minima/run_sparsification_llama.py", line 145, in main
    assert torch.isnan(loss) == False, "Loss is NaN!"
AssertionError: Loss is NaN!

GeneZC commented 7 months ago

Basically, I wrote the NaN loss detection in case of any loss spiking, which may potentially result in unexpected behavior during pruning. However, I have not encountered this issue during pruning in my experiments. The issue in your case is perhaps correlated with your data, so I suggest adding if-nan-then-continue logic to skip such data, or directly using float32 instead of float16.
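
Something like the following around the forward pass in run_sparsification_llama.py (a sketch only; the variable names are illustrative rather than the exact ones in the script):

import torch

for step, batch in enumerate(dataloader):
    output = model(**batch, head_mask=head_mask, neuron_mask=neuron_mask, hidden_mask=hidden_mask)
    loss = output.loss  # or output[0], depending on how the model returns it
    if not torch.isfinite(loss):
        # if-nan-then-continue: skip the spiking batch instead of asserting
        print(f"step {step}: non-finite loss, skipping this batch")
        continue
    # ... the rest of the importance accumulation stays unchanged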

BTW, I am trying to integrate deepspeed and flash attention into the pruning process so that you can achieve higher speed : /

GeneZC commented 7 months ago

Hi @l3utterfly

I have updated the pruning process with deepspeed and flash attention, which greatly reduces the compute time from more than a day to several hours.

Hope you will find it useful!

https://github.com/GeneZC/MiniMA/commit/eb91dedd042f2d7e8323c6ca8377d8789b8f563e

l3utterfly commented 7 months ago

Thank you so much! I will start a new run tonight!

l3utterfly commented 7 months ago

I tried to distill with the new code. I noticed the pruning process now recommends 8 A100 GPUs?

I am using 1 A100 GPU and it still runs out of memory with my 1GB dataset. A lower number of GPUs should only affect the time taken, not the memory usage, right?

GeneZC commented 7 months ago

With fewer GPUs, maybe the batch size should also be decreased accordingly, since deepspeed permits a larger batch size per GPU when more GPUs are used.
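
To give a rough sense of the arithmetic: with, say, 8 GPUs and a per-device batch of 2, the effective batch is 16 per step, and (assuming the deepspeed config shards states with ZeRO) each card also holds only about 1/8 of the sharded states, leaving more memory for activations. On a single card nothing is sharded, so the per-device batch has to shrink to compensate.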

l3utterfly commented 7 months ago

I reduced the per_device_eval_batch_size to 1.

Also, what should the values be here: --nnodes=$NODE_WORLD_SIZE --node_rank=$NODE_RANK?

GeneZC commented 7 months ago

In your case, the two parameters should be removed. And I am not quite sure whether deepspeed works with only 1 GPU or not.
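
If you still want to try on a single machine and the tutorial's launch line goes through torchrun (which is what --nnodes/--node_rank suggest), the single-node call would look roughly like

torchrun --nnodes=1 --nproc_per_node=1 run_sparsification_llama.py <same script arguments as before>

but again, I have not verified the 1-GPU case myself.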

l3utterfly commented 7 months ago

Should I remove the deepspeed argument then?

GeneZC commented 7 months ago

Not really; deepspeed is integrated for the speedup, and removing it will result in errors... Perhaps you could try with 2 GPUs or so.

l3utterfly commented 7 months ago

I don't have access to 2 GPUs sadly. Are you using the 40GB A100 or the 80GB version?

GeneZC commented 7 months ago

The 80GB version.

GeneZC commented 6 months ago

I am closing this issue since it is not active; feel free to reopen it as you like.