openlm-research / open_llama

OpenLLaMA, a permissively licensed open source reproduction of Meta AI’s LLaMA 7B trained on the RedPajama dataset
Apache License 2.0

Why do you only use causal_mask (xops.LowerTriangularMask()) when use_memory_efficient_attention is True? #87

Closed · EliverQ closed this issue 1 year ago

EliverQ commented 1 year ago

Awesome work! Thank you for releasing your checkpoints and code. However, I have a question about modeling_open_llama.py:

        if self.config.use_memory_efficient_attention and xops is not None and self.training:
            attn_weights = None
            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)
            attn_output = xops.memory_efficient_attention(
                query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask(), p=self.dropout_prob
            )

https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/models/open_llama/modeling_open_llama.py#L297-L304 Why do you use only xops.LowerTriangularMask() here? It seems that the model will attend to the padding tokens. I'd appreciate it if you could help me understand.
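
To make the concern concrete, here is a minimal, hypothetical PyTorch sketch (it is not OpenLLaMA or transformers code; all names, shapes, and the helper `attend()` are invented for the example) contrasting a purely causal mask with one that is additionally intersected with a padding mask:

```python
import torch

# Illustration only: neither OpenLLaMA nor transformers code.
torch.manual_seed(0)
batch, seq_len, dim = 1, 5, 8
q = torch.randn(batch, seq_len, dim)
k = torch.randn(batch, seq_len, dim)
v = torch.randn(batch, seq_len, dim)

# Suppose the last two positions are right-padding (True = real token).
padding_mask = torch.tensor([[True, True, True, False, False]])

# Causal (lower-triangular) mask: position i may attend to positions <= i,
# regardless of whether those positions are padding.
causal = torch.tril(torch.ones(seq_len, seq_len)).bool()

# Combined mask: causal AND the key position is a real token.
combined = causal & padding_mask[:, None, :]

def attend(mask):
    scores = (q @ k.transpose(-1, -2)) / dim ** 0.5
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

out_causal_only = attend(causal)     # queries at/after the padding can attend to padding keys
out_with_padding = attend(combined)  # padding keys are masked out everywhere

# With right-padding, real tokens precede the padding, so their rows are
# identical under both masks; only the padding positions themselves differ.
print(torch.allclose(out_causal_only[:, :3], out_with_padding[:, :3]))  # True
print(torch.allclose(out_causal_only[:, 3:], out_with_padding[:, 3:]))  # False
```

Note that with right-padded batches the causal mask already keeps real tokens from attending to the padding that follows them, so the difference shows up only at the padding positions; the picture changes for left-padded or packed inputs.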

young-geng commented 1 year ago

The Huggingface Open-LLaMA implementation is not related to this project. Our OpenLLaMA uses the standard LLaMA implementation in Huggingface.