Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License
13.39k stars 1.22k forks source link

The problem of dtype. #679

Open bxrjmfh opened 10 months ago

bxrjmfh commented 10 months ago

Problem and a workaround

I'm using flash_attn==2.3.3 with my fine-tuned LLaMA 2 model (13B), but I get an error when FlashAttention is enabled. When I call the model's `generate` method, /flash_attn/bert_padding.py#L41 raises `IndexError: tensors used as indices must be long, byte or bool tensors`. I worked around the problem by adding `.long()` to the indices:

class IndexPutFirstAxis(torch.autograd.Function):
    """Place the rows of ``values`` into a zero-filled tensor along axis 0.

    Forward writes ``values[i]`` to row ``indices[i]`` of a fresh tensor that
    has ``first_axis_dim`` rows; backward reads those same rows back out of the
    incoming gradient. ``indices`` is cast with ``.long()`` at both indexing
    sites -- that cast is the workaround for ``IndexError: tensors used as
    indices must be long, byte or bool tensors``.
    """

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim):
        """Return a ``(first_axis_dim, *values.shape[1:])`` tensor holding ``values``.

        Rows that are not listed in ``indices`` remain zero. ``indices`` must
        be 1-D and ``values`` must have at least two dimensions. The result
        uses the device and dtype of ``values``.
        """
        # Stored in its incoming dtype; backward() applies the cast again.
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        padded_shape = (first_axis_dim, *values.shape[1:])
        padded = torch.zeros(padded_shape, device=values.device, dtype=values.dtype)
        # Workaround (added): cast to int64 so that indexing accepts the tensor.
        row_ids = indices.long()
        padded[row_ids] = values
        # Original author's note (TD, 2022-03-04): "For some reason
        # torch.scatter is a bit faster than indexing." The scatter variant
        # is left disabled:
        # padded.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
        return padded

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient for ``values``; the other two inputs get ``None``."""
        (saved_indices,) = ctx.saved_tensors
        # Workaround (added): same int64 cast as in forward().
        row_ids = saved_indices.long()
        grad_values = grad_output[row_ids]
        # Original author's note (TD, 2022-03-04): "For some reason
        # torch.gather is a bit faster than indexing." The gather variant is
        # left disabled:
        # grad_values = torch.gather(grad_output, 0, repeat(saved_indices, 'z -> z d', d=grad_output.shape[1]))
        return grad_values, None, None

Error message

Here is the full error traceback:


IndexError Traceback (most recent call last) /root/VLN_2023/MiC/test_accelerate_llama.ipynb Cell 5 line 5 3 inputs = tokenizer(batch, return_tensors="pt", padding=True).to(device) 4 # Generate ----> 5 generate_ids = model.generate(inputs.input_ids, max_length=500) 6 res = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) 7 ress.append(res)

File ~/miniconda3/envs/duet/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.call..decorate_context(*args, kwargs) 24 @functools.wraps(func) 25 def decorate_context(*args, *kwargs): 26 with self.clone(): ---> 27 return func(args, kwargs)

File ~/miniconda3/envs/duet/lib/python3.8/site-packages/transformers/generation/utils.py:1673, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, kwargs) 1656 return self.assisted_decoding( 1657 input_ids, 1658 assistant_model=assistant_model, (...) 1669 model_kwargs, 1670 ) 1671 if generation_mode == GenerationMode.GREEDY_SEARCH: 1672 # 11. run greedy search -> 1673 return self.greedy_search( 1674 input_ids, 1675 logits_processor=logits_processor, 1676 stopping_criteria=stopping_criteria, 1677 pad_token_id=generation_config.pad_token_id, 1678 eos_token_id=generation_config.eos_token_id, 1679 output_scores=generation_config.output_scores, 1680 return_dict_in_generate=generation_config.return_dict_in_generate, 1681 synced_gpus=synced_gpus, 1682 streamer=streamer, 1683 **model_kwargs, 1684 ) 1686 elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH: 1687 if not model_kwargs["use_cache"]:

File ~/miniconda3/envs/duet/lib/python3.8/site-packages/transformers/generation/utils.py:2521, in GenerationMixin.greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, model_kwargs) 2518 model_inputs = self.prepare_inputs_for_generation(input_ids, model_kwargs) 2520 # forward pass to get next token -> 2521 outputs = self( 2522 **model_inputs, 2523 return_dict=True, 2524 output_attentions=output_attentions, 2525 output_hidden_states=output_hidden_states, 2526 ) 2528 if synced_gpus and this_peer_finished: 2529 continue # don't waste resources running the code we don't need

File ~/miniconda3/envs/duet/lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, *kwargs) 1190 # If we don't have any hooks, we want to skip the rest of the logic in 1191 # this function, and just call forward. 1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1193 or _global_forward_hooks or _global_forward_pre_hooks): -> 1194 return forward_call(input, **kwargs) 1195 # Do not call functions when jit is used 1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/duet/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:1034, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict) 1031 return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1033 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) -> 1034 outputs = self.model( 1035 input_ids=input_ids, 1036 attention_mask=attention_mask, 1037 position_ids=position_ids, 1038 past_key_values=past_key_values, 1039 inputs_embeds=inputs_embeds, 1040 use_cache=use_cache, 1041 output_attentions=output_attentions, 1042 output_hidden_states=output_hidden_states, 1043 return_dict=return_dict, 1044 ) 1046 hidden_states = outputs[0] 1047 if self.config.pretraining_tp > 1:

File ~/miniconda3/envs/duet/lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, *kwargs) 1190 # If we don't have any hooks, we want to skip the rest of the logic in 1191 # this function, and just call forward. 1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1193 or _global_forward_hooks or _global_forward_pre_hooks): -> 1194 return forward_call(input, **kwargs) 1195 # Do not call functions when jit is used 1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/duet/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:922, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict) 912 layer_outputs = self._gradient_checkpointing_func( 913 decoder_layer.call, 914 hidden_states, (...) 919 use_cache, 920 ) 921 else: --> 922 layer_outputs = decoder_layer( 923 hidden_states, 924 attention_mask=attention_mask, 925 position_ids=position_ids, 926 past_key_value=past_key_value, 927 output_attentions=output_attentions, 928 use_cache=use_cache, 929 ) 931 hidden_states = layer_outputs[0] 933 if use_cache:

File ~/miniconda3/envs/duet/lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, *kwargs) 1190 # If we don't have any hooks, we want to skip the rest of the logic in 1191 # this function, and just call forward. 1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1193 or _global_forward_hooks or _global_forward_pre_hooks): -> 1194 return forward_call(input, **kwargs) 1195 # Do not call functions when jit is used 1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/duet/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:672, in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, kwargs) 669 hidden_states = self.input_layernorm(hidden_states) 671 # Self Attention --> 672 hidden_states, self_attn_weights, present_key_value = self.self_attn( 673 hidden_states=hidden_states, 674 attention_mask=attention_mask, 675 position_ids=position_ids, 676 past_key_value=past_key_value, 677 output_attentions=output_attentions, 678 use_cache=use_cache, 679 kwargs, 680 ) 681 hidden_states = residual + hidden_states 683 # Fully Connected

File ~/miniconda3/envs/duet/lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, *kwargs) 1190 # If we don't have any hooks, we want to skip the rest of the logic in 1191 # this function, and just call forward. 1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1193 or _global_forward_hooks or _global_forward_pre_hooks): -> 1194 return forward_call(input, **kwargs) 1195 # Do not call functions when jit is used 1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/duet/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:521, in LlamaFlashAttention2.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs) 518 key_states = key_states.to(target_dtype) 519 value_states = value_states.to(target_dtype) --> 521 attn_output = self._flash_attention_forward( 522 query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate 523 ) 525 attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() 526 attn_output = self.o_proj(attn_output)

File ~/miniconda3/envs/duet/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py:578, in LlamaFlashAttention2._flash_attention_forward(self, query_states, key_states, value_states, attention_mask, query_length, dropout, softmax_scale) 563 max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 565 attn_output_unpad = flash_attn_varlen_func( 566 query_states, 567 key_states, (...) 575 causal=self.is_causal, 576 ) --> 578 attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) 579 else: 580 attn_output = flash_attn_func( 581 query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal 582 )

File ~/miniconda3/envs/duet/lib/python3.8/site-packages/flash_attn/bert_padding.py:208, in pad_input(hidden_states, indices, batch, seqlen) 205 dim = hidden_states.shape[-1] 206 # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) 207 # output[indices] = hidden_states --> 208 output = index_put_first_axis(hidden_states, indices, batch * seqlen) 209 return rearrange(output, "(b s) ... -> b s ...", b=batch)

File ~/miniconda3/envs/duet/lib/python3.8/site-packages/flash_attn/bert_padding.py:51, in IndexPutFirstAxis.forward(ctx, values, indices, first_axis_dim) 47 output = torch.zeros( 48 first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype 49 ) 50 # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. ---> 51 output[indices] = values 52 # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) 53 return output

IndexError: tensors used as indices must be long, byte or bool tensors

My code pieces


# Reproduction script: load a fine-tuned LLaMA 2 checkpoint with
# FlashAttention 2 enabled, read batches of prompts from a JSONL file and
# generate a short continuation for every batch.
import json
import torch
from transformers import AutoTokenizer, LlamaForCausalLM

device = 'cuda:1'
PATH_TO_CONVERTED_WEIGHTS = '/root/VLN_2023/llama2/llama2_hf/hf_ckpt_r2r_v3'
dtype = torch.float16
model = LlamaForCausalLM.from_pretrained(
    PATH_TO_CONVERTED_WEIGHTS,
    torch_dtype=dtype,
    device_map=device,
    use_flash_attention_2=True,
)

# Each JSONL line holds one batch: a list of records. Only element 3 of every
# record is kept and used as the prompt text.
hint_data = []
recored_path = '/root/VLN_2023/HM3DAutoVLN/records/20231113-115042(p3)(fix_all)/hint_record/hint_history.jsonl'
with open(recored_path, "r", encoding="utf-8") as file:
    for line in file:
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            continue
        # Parse the JSON object on each line
        data = json.loads(line)
        data = [x[3] for x in data]
        hint_data.append(data)
# NOTE(review): the snippet below presumably shows how the records were
# written (index 3 would then be obs[i]['prompt']) -- confirm with the producer.
# log.append([sub_inss[counter],fmt_inss[counter]])
# hint_rec = [[obs[i]['instr_id'],
#             obs[i]['scan'],
#             obs[i]['viewpoint'],
#             obs[i]['prompt'],
#             log[i]] for i,flag in enumerate(help_points) if flag == True]

# Load the tokenizer once. The keyword is `padding_side` (the original
# `padding_size` was a typo and never enabled left padding); decoder-only
# models must be left-padded for batched generation.
tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_WEIGHTS, padding_side="left")
if tokenizer.pad_token is None:
    # Register a pad token so that `padding=True` works. The original checked
    # for '<pad>' but added the misspelled token "<apd>".
    tokenizer.add_special_tokens({"pad_token": "<pad>"})

# Resize the embeddings (covers the case where a new token was added)
model.resize_token_embeddings(len(tokenizer))

# Configure the pad token in the model
model.config.pad_token_id = tokenizer.pad_token_id

# Check if they are equal
assert model.config.pad_token_id == tokenizer.pad_token_id, "The model's pad token ID does not match the tokenizer's pad token ID!"

ress = []
for batch in hint_data:
    inputs = tokenizer(batch, return_tensors="pt", padding=True).to(device)
    # Generate. Pass the attention mask explicitly so padded positions are
    # masked out instead of being inferred from the pad token id.
    generate_ids = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=15,
    )
    res = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    ress.append(res)
bxrjmfh commented 10 months ago

I'm happy to provide any further details that would help. I hope this is useful.