Hi,
I am trying to fine-tune a 7B model for a 16k context length on an 8-GPU machine (A100, 40 GB each), but I am getting the following runtime error:
Traceback (most recent call last):
File "/home/ec2-user/data/yarn/finetune.py", line 222, in <module>
main(args.parse_args())
File "/home/ec2-user/data/yarn/finetune.py", line 150, in main
loss = model(**batch).loss
File "/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 1801, in forward
loss = self.module(*inputs, **kwargs)
File "/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
result = forward_call(*args, **kwargs)
File "/home/ec2-user/data/yarn/scaled_rope/modeling_llama_together_yarn.py", line 985, in forward
outputs = self.model(
File "/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
result = forward_call(*args, **kwargs)
File "/home/ec2-user/data/yarn/scaled_rope/modeling_llama_together_yarn.py", line 860, in forward
layer_outputs = torch.utils.checkpoint.checkpoint(
File "/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 249, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
File "/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 107, in forward
outputs = run_function(*args)
File "/home/ec2-user/data/yarn/scaled_rope/modeling_llama_together_yarn.py", line 856, in custom_forward
return module(*inputs, output_attentions, None)
File "/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
result = forward_call(*args, **kwargs)
File "/home/ec2-user/data/yarn/scaled_rope/modeling_llama_together_yarn.py", line 620, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1538, in _call_impl
result = forward_call(*args, **kwargs)
File "/home/ec2-user/data/yarn/scaled_rope/modeling_llama_together_yarn.py", line 555, in forward
).reshape(bsz, q_len, h_size)
RuntimeError: shape '[1, 16384, 4096]' is invalid for input of size 13459456
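Doing the arithmetic on the error message: 13459456 / 4096 = 3286, so the tensor being reshaped only holds 3286 tokens' worth of hidden states, not the 16384 that q_len expects (4096 is Llama-2-7b's hidden size). A quick sanity check of those numbers:

numel = 13459456   # element count reported in the RuntimeError
h_size = 4096      # hidden size of Llama-2-7b
q_len = 16384      # sequence length the reshape expects

tokens = numel // h_size
print(numel % h_size == 0, tokens)  # True 3286 -> exactly 3286 tokens
print(tokens == q_len)              # False: 3286 != 16384

So the output of self_attn contains far fewer tokens than the expected sequence length. My guess (and it is only a guess) is that the batch contains shorter, padded sequences and the attention path returns only the unpadded tokens, while the reshape assumes a fixed q_len.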
Here is the command:
accelerate launch finetune.py --wandb yarn --output-dir output/yarn-7b-16k --model meta-llama/Llama-2-7b-chat-hf --max-train-steps 20 --scaling-factor 4 --scaling-type yarn --seed 31337 --dataset shossain/govreport-qa-5-16384 --gradient-accumulate-every 1
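To see whether every example in the dataset is really 16384 tokens long (the 3286 above suggests shorter, variable-length sequences), here is a sketch of the check I have in mind; the input_ids column name is an assumption on my part:

from datasets import load_dataset

ds = load_dataset("shossain/govreport-qa-5-16384", split="train")
# "input_ids" is assumed to be the tokenized column; adjust if the schema differs.
lengths = [len(ex["input_ids"]) for ex in ds]
print(min(lengths), max(lengths))  # min < 16384 would mean the batches contain padding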
Please suggest what might be going wrong here and how to fix it.