jquesnelle / yarn

YaRN: Efficient Context Window Extension of Large Language Models

Runtime error #31

Open shossain opened 9 months ago

shossain commented 9 months ago

Hi, I am trying to fine-tune a 7B model for a 16k context length on an 8-GPU A100 (40 GB) machine, but I am getting the following runtime error:

Traceback (most recent call last):
File "/home/ec2-user/data/yarn/finetune.py", line 222, in <module>
    main(args.parse_args())
  File "/home/ec2-user/data/yarn/finetune.py", line 150, in main
    loss = model(**batch).loss
  File "/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1801, in forward
    loss = self.module(*inputs, **kwargs)
  File "/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/ec2-user/data/yarn/scaled_rope/modeling_llama_together_yarn.py", line 985, in forward
    outputs = self.model(
  File "/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/ec2-user/data/yarn/scaled_rope/modeling_llama_together_yarn.py", line 860, in forward
    layer_outputs = torch.utils.checkpoint.checkpoint(
  File "/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 249, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 107, in forward
    outputs = run_function(*args)
  File "/home/ec2-user/data/yarn/scaled_rope/modeling_llama_together_yarn.py", line 856, in custom_forward
    return module(*inputs, output_attentions, None)
  File "/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/ec2-user/data/yarn/scaled_rope/modeling_llama_together_yarn.py", line 620, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/home/ec2-user/data/yarn/scaled_rope/modeling_llama_together_yarn.py", line 555, in forward
    ).reshape(bsz, q_len, h_size)
RuntimeError: shape '[1, 16384, 4096]' is invalid for input of size 13459456

Here is the command:

accelerate launch finetune.py --wandb yarn --output-dir output/yarn-7b-16k --model meta-llama/Llama-2-7b-chat-hf --max-train-steps 20 --scaling-factor 4 --scaling-type yarn --seed 31337 --dataset shossain/govreport-qa-5-16384 --gradient-accumulate-every 1
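
For what it's worth, the reshape target [1, 16384, 4096] needs 16384 × 4096 = 67,108,864 elements, while the tensor only has 13,459,456, which works out to a sequence length of 3,286 tokens, so it looks as if the batch never reaches the full 16k context. A quick length check along these lines (a rough sketch; it assumes the dataset's train split exposes a tokenized input_ids column, so adjust the names as needed) might confirm whether the examples are actually padded or packed to 16384:

    from datasets import load_dataset

    # NOTE: split/column names here are assumptions about the dataset layout.
    ds = load_dataset("shossain/govreport-qa-5-16384", split="train")
    lengths = [len(example["input_ids"]) for example in ds]
    print("min length:", min(lengths))
    print("max length:", max(lengths))
    print("examples shorter than 16384:", sum(l < 16384 for l in lengths))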

Any suggestions would be appreciated.