Closed by mhillebrand.
When doing inference on Gemma-2-2B with Flash Attention 2, I get the following error. It works just fine with Flash Attention disabled.
transformers==4.44.0 torch==2.4.0 flash-attn==2.6.3 python==3.12.4
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Fri_Jun_14_16:34:21_PDT_2024
Cuda compilation tools, release 12.6, V12.6.20
Build cuda_12.6.r12.6/compiler.34431801_0
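If you need to produce a version line like the one above for a report, a small helper built on the standard library's `importlib.metadata` can do it. This is a sketch, not part of any project's tooling: `collect_versions` and `format_versions` are hypothetical names, and the distribution names passed in are the ones listed in this report.

```python
import platform
from importlib.metadata import PackageNotFoundError, version


def collect_versions(packages):
    """Map each distribution name to its installed version, then add Python."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # Keep the key so the report shows what was checked.
            found[name] = "not installed"
    found["python"] = platform.python_version()
    return found


def format_versions(found):
    """Render the mapping in the same `name==version` style used in this issue."""
    return " ".join(f"{name}=={ver}" for name, ver in found.items())


if __name__ == "__main__":
    print(format_versions(collect_versions(["transformers", "torch", "flash-attn"])))
```

Run as a script inside the affected environment, it prints a single line in the same shape as the environment line above.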
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [42,0,0], thread: [32,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
(the same assertion repeats for threads [33,0,0] through [62,0,0])
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [42,0,0], thread: [63,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
Traceback (most recent call last):
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/matt/mr/engine/utilz/common.py", line 244, in time_closure
    result = func(*args, **kwargs)
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/transformers/pipelines/text_generation.py", line 262, in __call__
    return super().__call__(text_inputs, **kwargs)
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/transformers/pipelines/base.py", line 1257, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/transformers/pipelines/base.py", line 1264, in run_single
    model_outputs = self.forward(model_inputs, **forward_params)
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/transformers/pipelines/base.py", line 1164, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/transformers/pipelines/text_generation.py", line 351, in _forward
    generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/transformers/generation/utils.py", line 2024, in generate
    result = self._sample(
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/transformers/generation/utils.py", line 2982, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 999, in forward
    outputs = self.model(
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 847, in forward
    layer_outputs = decoder_layer(
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 590, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 423, in forward
    attn_output = _flash_attention_forward(
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py", line 246, in _flash_attention_forward
    query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/transformers/modeling_flash_attention_utils.py", line 102, in _upad_input
    key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/matt/miniconda3/envs/infer/lib/python3.12/site-packages/flash_attn/bert_padding.py", line 17, in forward
    return torch.gather(
RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
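The traceback ends in `_upad_input`, which flattens the key tensor to `batch_size * kv_seq_len` rows and then gathers rows using indices derived from the attention mask. The failing assertion (`idx_dim < index_size`) fires whenever one of those indices is at least the number of rows. Below is a dependency-free model of that index arithmetic. It assumes, without confirmation from this thread, that the mask covers more positions than the key tensor; `unpad_indices` and `gather_rows` are hypothetical simplified stand-ins, not the transformers or flash-attn functions.

```python
def unpad_indices(attention_mask):
    """Flat row-major positions of non-padding tokens over (batch, seq).

    Models the idea of torch.nonzero(attention_mask.flatten()) in the unpad step.
    """
    seq_len = len(attention_mask[0])
    return [
        b * seq_len + s
        for b, row in enumerate(attention_mask)
        for s, keep in enumerate(row)
        if keep
    ]


def gather_rows(rows, indices):
    """Select rows by index, enforcing the same bound the CUDA gather kernel asserts."""
    index_size = len(rows)
    for idx in indices:
        if not (0 <= idx < index_size):
            raise IndexError(
                f"index {idx} is out of bounds for dimension 0 with size {index_size}"
            )
    return [rows[idx] for idx in indices]


# Keys flattened to batch_size * kv_seq_len = 1 * 4 rows.
keys = [f"k{i}" for i in range(4)]

# Mask length matches the key length: every index is in range.
print(gather_rows(keys, unpad_indices([[0, 1, 1, 1]])))  # ['k1', 'k2', 'k3']

# Mask with 6 positions against 4 key rows: indices 4 and 5 are out of range.
try:
    gather_rows(keys, unpad_indices([[1, 1, 1, 1, 1, 1]]))
except IndexError as err:
    print(err)  # index 4 is out of bounds for dimension 0 with size 4
```

On CPU, `torch.gather` reports this kind of mismatch as an ordinary `RuntimeError`; on CUDA the bound check is a device-side assertion, which is why the log shows one failed assertion per thread before the generic "device-side assert triggered" error.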
I encountered the same issue as you.
This is no longer an issue with Transformers v4.44.1.
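If you cannot upgrade right away, one option is to request Flash Attention 2 only on a Transformers release at or above the one reported fixed here, and otherwise use a non-flash implementation, which the original report says works. This is a sketch under that assumption: `choose_attn_implementation`, `release_tuple` and the `FIXED_IN` threshold are hypothetical names, and the threshold comes only from the comment above, not from a changelog. Pre-release strings such as `4.44.1rc1` are treated as 4.44.1 by this simple parser.

```python
import re

# Threshold taken from this thread's report that v4.44.1 no longer shows the error.
FIXED_IN = (4, 44, 1)


def release_tuple(ver):
    """Leading numeric release components, e.g. '4.45.0.dev0' -> (4, 45, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        raise ValueError(f"unrecognised version string: {ver!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def choose_attn_implementation(transformers_version, want_flash=True):
    """Return a value for the `attn_implementation` argument of from_pretrained."""
    if want_flash and release_tuple(transformers_version) >= FIXED_IN:
        return "flash_attention_2"
    # "eager" is one non-flash option; "sdpa" is another.
    return "eager"
```

In use, you would pass `attn_implementation=choose_attn_implementation(importlib.metadata.version("transformers"))` to `AutoModelForCausalLM.from_pretrained(...)`. Keep in mind that Flash Attention 2 also requires the model to be loaded in fp16 or bf16.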