Infini-AI-Lab / TriForce

[COLM 2024] TriForce: Lossless Acceleration of Long Sequence Generation with Hierarchical Speculative Decoding
https://infini-ai-lab.github.io/TriForce/

Attention Scores Matrix Visualization #10

Open bulaikexiansheng opened 3 months ago

bulaikexiansheng commented 3 months ago

Hi, I would like to ask why the attention mask is not used in the prefill stage. I want to output the attention score matrix during the prefill stage. Is the code below correct?

        if spec: # spec decoding
            key_states, value_states = graph_cache.update(new_k_cache=key_states, new_v_cache=value_states, layer_idx=self.layer_idx)
        else:
            # update kv cache first
            key_states, value_states = kv_cache.update(key_states, value_states, layer_idx=self.layer_idx)
            if query_states.shape[1] == 1 and (isinstance(graph_cache, RetrievalCache)): 
                if graph_cache.init_graph == False:
                    # init graph cache
                    graph_cache.init_graph_cache(kv_cache, query_states, self.layer_idx)
                else:
                    # update graph cache (customized)
                    graph_cache.update_graph_cache_retrieval(kv_cache, query_states, self.layer_idx)

        # compute the attention score matrix
        attention_scores = torch.einsum("bqhd,bkhd->bhqk", query_states, key_states)
        attention_scores /= torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
        if attention_mask is not None:
            attention_mask = attention_mask.to(attention_scores.device)
            attention_scores += attention_mask

        attn_output = flash_attn_with_kvcache(q=query_states, k_cache=key_states, v_cache=value_states, softmax_scale=1/torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float16)), causal=True)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, attention_scores

preminstrel commented 3 months ago

Hello,

We use the flash attention function, which already applies a causal mask during the prefill phase, so no separate attention mask is needed.
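
To make manually computed scores comparable with the `causal=True` path, the same causal pattern would have to be applied to them before the softmax. The snippet below is a minimal, library-free sketch of that pattern, not code from the repository. The helper name `causal_allowed` is hypothetical, and it assumes the bottom-right alignment that recent flash-attn releases document for `causal=True` when the query and key lengths differ.

```python
def causal_allowed(q_len, k_len):
    """Return allowed[i][j], True when query i may attend to key j.

    Hypothetical helper for illustration only. Assumes the mask is aligned
    to the bottom-right corner of the (q_len x k_len) score matrix.
    """
    offset = k_len - q_len  # 0 during prefill, where q_len == k_len
    return [[j <= i + offset for j in range(k_len)] for i in range(q_len)]


# Prefill with 3 tokens: a plain lower-triangular pattern.
for row in causal_allowed(3, 3):
    print(row)

# Decoding one new token against 4 cached keys: every key is visible.
print(causal_allowed(1, 4))
```

Positions marked `False` would be set to negative infinity in the score matrix before the softmax. In PyTorch the prefill case corresponds to a lower-triangular mask such as the one produced by `torch.tril`.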

Note that computing the attention matrix directly can easily run out of memory (OOM) for long sequences, because the score matrix grows quadratically with sequence length.
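
A rough back-of-the-envelope estimate shows why this happens. The sketch below only counts the bytes needed to hold the `bhqk` score tensor for one layer. The shapes used (batch 1, 32 heads, 131072 tokens, 2 bytes per element for fp16) and the chunk size of 256 queries are illustrative assumptions, not values taken from this thread.

```python
def attention_scores_bytes(batch, num_heads, q_len, k_len, bytes_per_elem=2):
    """Bytes needed to store a (batch, heads, q_len, k_len) score tensor."""
    return batch * num_heads * q_len * k_len * bytes_per_elem


GIB = 1024 ** 3

# Assumed shapes for illustration: 32 heads, 128K-token prompt, fp16 scores.
full = attention_scores_bytes(1, 32, 131072, 131072)
print(full / GIB)   # 1024.0 GiB for a single layer

# Scoring only 256 query rows at a time keeps the buffer far smaller.
chunk = attention_scores_bytes(1, 32, 256, 131072)
print(chunk / GIB)  # 2.0 GiB per chunk
```

If the full matrix is really needed, one option is to compute it in query chunks like this and move each chunk to CPU or reduce it (for example, to per-key averages) before processing the next one.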