princeton-nlp / ProLong

Homepage for ProLong (Princeton long-context language models) and paper "How to Train Long-Context Language Models (Effectively)"
MIT License

[Bug] A tensor index bug in LLama Attention #3

Closed Chenfeng1271 closed 1 week ago

Chenfeng1271 commented 1 week ago

Hi, thank you for your excellent work. I am currently developing an efficient long-context training method based on this code, but when I manipulate the query in LLaMA attention with a simple index like `q[0:,]`, it raises an error...

```
(Pdb) q[0]
input_ids: [806, 13, 17, 13642, 320, 7261, 4364, 2137, 5943, 8, 2355, 197, 9391, 5778, 25, 3851, 43801, 92231, 7502, 12861, 3933, 41779, 29148, 1326, 10169, 198, 197, 12, 197, 63, 15605, 4146, 25, 24, 67, 21144, 2371, 2042, 2
Traceback (most recent call last):
  File "/usr/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/mnt/bn/tns-algo-public-my2/zhangjiyuan/chenfeng/code/longvideollm/ProLong/training/train_language_model.py", line 256, in <module>
    main()
  File "/mnt/bn/tns-algo-public-my2/zhangjiyuan/chenfeng/code/longvideollm/ProLong/training/train_language_model.py", line 225, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/mnt/bn/tns-algo-public-my2/zhangjiyuan/chenfeng/code/longvideollm/ProLong/training/trainer.py", line 1028, in train
    return inner_training_loop(
  File "/home/tiger/.local/lib/python3.9/site-packages/transformers/trainer.py", line 2388, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/tiger/.local/lib/python3.9/site-packages/transformers/trainer.py", line 3485, in training_step
    loss = self.compute_loss(model, inputs)
  File "/mnt/bn/tns-algo-public-my2/zhangjiyuan/chenfeng/code/longvideollm/ProLong/training/trainer.py", line 342, in compute_loss
    outputs = model(inputs, use_cache=False)
  File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/accelerate/utils/operations.py", line 820, in forward
    return model_forward(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/accelerate/utils/operations.py", line 808, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/tiger/.local/lib/python3.9/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/accelerate/utils/operations.py", line 820, in forward
    return model_forward(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/accelerate/utils/operations.py", line 808, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/home/tiger/.local/lib/python3.9/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
    return func(*args, **kwargs)
  File "/mnt/bn/tns-algo-public-my2/zhangjiyuan/chenfeng/code/longvideollm/ProLong/training/modeling_flash_llama.py", line 889, in forward
    outputs = self.model(
  File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/bn/tns-algo-public-my2/zhangjiyuan/chenfeng/code/longvideollm/ProLong/training/modeling_flash_llama.py", line 718, in forward
    layer_outputs = torch.utils.checkpoint.checkpoint(
  File "/home/tiger/.local/lib/python3.9/site-packages/torch/_compile.py", line 31, in inner
    return disable_fn(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 488, in checkpoint
    ret = function(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/bn/tns-algo-public-my2/zhangjiyuan/chenfeng/code/longvideollm/ProLong/training/modeling_flash_llama.py", line 589, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/bn/tns-algo-public-my2/zhangjiyuan/chenfeng/code/longvideollm/ProLong/training/modeling_flash_llama.py", line 471, in forward
    probe_query_states = q[:,probe_ids,:]#.detach()
  File "/usr/lib/python3.9/bdb.py", line 88, in trace_dispatch
    return self.dispatch_line(frame)
  File "/usr/lib/python3.9/bdb.py", line 113, in dispatch_line
    if self.quitting: raise BdbQuit
```

The strangest part is that I can index k and v without any problem. Do you know how to fix this? Thank you very much.
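For reference, the indexing itself is valid PyTorch; a minimal standalone sketch (with made-up shapes, and run outside FSDP / activation checkpointing / pdb) executes without error, which suggests the failure comes from the surrounding environment rather than the indexing operation:

```python
import torch

# Hypothetical shapes: (batch, seq_len, num_heads, head_dim), as in a
# typical LLaMA attention implementation before transposition.
q = torch.randn(2, 16, 8, 64)
probe_ids = torch.tensor([0, 3, 7])

# Both forms of indexing from the report, on a plain tensor.
first = q[0:, ]              # slice from index 0 -> same shape as q
probe = q[:, probe_ids, :]   # gather probe positions along the seq dim

print(first.shape)   # torch.Size([2, 16, 8, 64])
print(probe.shape)   # torch.Size([2, 3, 8, 64])
```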

Chenfeng1271 commented 1 week ago

I am using two H100 80GB GPUs and pdb for debugging. Watching nvidia-smi, GPU memory usage is fine. The problem is not limited to the simple index manipulation: I also cannot print the size of q.
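As a debugging aside (a minimal sketch with a hypothetical stand-in tensor): attribute access like `.shape`, `.dtype`, and `.device` only reads tensor metadata and does not launch a kernel or touch tensor values, so it is usually a cheap first check from inside pdb before trying to print the tensor itself:

```python
import torch

# Hypothetical stand-in for the q tensor inside the attention forward.
q = torch.randn(2, 16, 8, 64)

# Metadata-only accesses; no CUDA kernel launch or device synchronization.
shape, dtype = q.shape, q.dtype
print(shape, dtype)
```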