lancopku / label-words-are-anchors

Repository for Label Words are Anchors: An Information Flow Perspective for Understanding In-Context Learning

Questions about how Figure 1 was drawn, and whether the method applies to BERT/RoBERTa-style models #8

Open · lczx666 opened this issue 9 months ago

lczx666 commented 9 months ago

I read your EMNLP 2023 best paper and was very impressed by how clearly Figure 1 presents the idea. Do you plan to release the plotting code for it? In addition, I would like to measure the paper's metrics on BERT/RoBERTa-style PLMs; does your code support this?

leanwang326 commented 9 months ago

The figure was drawn with bertviz: once the saliency has been computed (this is in our code, i.e. saliency = attentionermanger.grad(use_abs=True)[i] in attention_attr.py), you can simply plot it with bertviz. (The yellow highlights were added manually for emphasis; they are not produced by the code.)
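
A rough sketch of how the saliency could then be passed to bertviz's head_view, treating the scores as if they were attention weights. It assumes each per-layer saliency tensor has the same (1, num_heads, seq_len, seq_len) shape as the attention matrix; the function and variable names below are illustrative, not part of the repo.

import torch
from bertviz import head_view

def plot_saliency(saliency_per_layer, tokens):
    # saliency_per_layer: list of per-layer tensors, e.g. attentionermanger.grad(use_abs=True),
    # each assumed to be shaped (1, num_heads, seq_len, seq_len) like an attention matrix
    pseudo_attention = []
    for s in saliency_per_layer:
        s = s.detach().float()
        # rescale to [0, 1] so bertviz renders the edge weights sensibly
        s = (s - s.min()) / (s.max() - s.min() + 1e-12)
        pseudo_attention.append(s)
    # head_view expects a list of (1, num_heads, seq_len, seq_len) tensors plus the token strings
    head_view(pseudo_attention, tokens)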

leanwang326 commented 9 months ago

Our code can also run on BERT-style models; in principle there is no problem. Note, however, that they use bidirectional attention rather than GPT-style unidirectional attention, so our metrics may need some adjustment.

If you just want to get it running, you will probably need to write a BertAttentionerManager analogous to GPT2AttentionerManager (see the issue "Does this apply to the llama2 model").

lczx666 commented 8 months ago

Thank you very much for the reply. I actually started reading your source code right after your answer, but unfortunately I still haven't figured out how to write BertAttentionerManager. Do you plan to add support for BERT/RoBERTa-style models later? Or is there a simpler way to just get the current code running? Looking forward to your reply.

leanwang326 commented 8 months ago

import math
from functools import partial
from typing import Optional, Tuple

import torch
from torch import nn
from transformers import PreTrainedModel

# AttentionAdapter and AttentionerManagerBase are the repo's existing classes
# (the same ones GPT2AttentionerManager builds on).

def bert_attn(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    output_attentions: Optional[bool] = False,
    attention_adapter=None,
) -> Tuple[torch.Tensor]:
    mixed_query_layer = self.query(hidden_states)

    # If this is instantiated as a cross-attention module, the keys
    # and values come from an encoder; the attention mask needs to be
    # such that the encoder's padding tokens are not attended to.
    is_cross_attention = encoder_hidden_states is not None

    if is_cross_attention and past_key_value is not None:
        # reuse k,v, cross_attentions
        key_layer = past_key_value[0]
        value_layer = past_key_value[1]
        attention_mask = encoder_attention_mask
    elif is_cross_attention:
        key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
        value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
        attention_mask = encoder_attention_mask
    elif past_key_value is not None:
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
        value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
    else:
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

    query_layer = self.transpose_for_scores(mixed_query_layer)

    use_cache = past_key_value is not None
    if self.is_decoder:
        # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
        # Further calls to cross_attention layer can then reuse all cross-attention
        # key/value_states (first "if" case)
        # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
        # all previous decoder key/value_states. Further calls to uni-directional self-attention
        # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
        # if encoder bi-directional self-attention `past_key_value` is always `None`
        past_key_value = (key_layer, value_layer)

    # Take the dot product between "query" and "key" to get the raw attention scores.
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

    if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
        query_length, key_length = query_layer.shape[2], key_layer.shape[2]
        if use_cache:
            position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                -1, 1
            )
        else:
            position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
        position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
        distance = position_ids_l - position_ids_r

        positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
        positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

        if self.position_embedding_type == "relative_key":
            relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
            attention_scores = attention_scores + relative_position_scores
        elif self.position_embedding_type == "relative_key_query":
            relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
            relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
            attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

    attention_scores = attention_scores / math.sqrt(self.attention_head_size)

    if attention_mask is not None:
        # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
        attention_scores = attention_scores + attention_mask

    # Normalize the attention scores to probabilities.
    attention_probs = nn.functional.softmax(attention_scores, dim=-1)

    # ==========================================
    # added: route the attention probabilities through the AttentionAdapter
    # so that gradients w.r.t. the attention matrix can be collected
    # ==========================================
    if attention_adapter is not None:
        attention_probs = attention_adapter(attention_probs)
    # ==========================================
    # end of the added block
    # ==========================================

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = self.dropout(attention_probs)

    # Mask heads if we want to
    if head_mask is not None:
        attention_probs = attention_probs * head_mask

    context_layer = torch.matmul(attention_probs, value_layer)

    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    context_layer = context_layer.view(new_context_layer_shape)

    outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

    if self.is_decoder:
        outputs = outputs + (past_key_value,)
    return outputs

class BertAttentionerManager(AttentionerManagerBase):
    def __init__(self, model: PreTrainedModel):
        super().__init__(model)

    def register_attentioner_to_model(self):
        attention_adapters = []
        # assumes the wrapped model exposes a .bert submodule (e.g. BertForSequenceClassification)
        for i, layer in enumerate(self.model.bert.encoder.layer):
            attention_adapter = AttentionAdapter()
            # patch the forward of each BertSelfAttention module, mirroring how
            # GPT2AttentionerManager patches layer.attn.forward
            layer.attention.self.forward = partial(bert_attn, layer.attention.self,
                                                   attention_adapter=attention_adapter)
            attention_adapters.append(attention_adapter)
        return attention_adapters
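
A minimal usage sketch, assuming the manager exposes grad(use_abs=True) the same way it does for GPT-2 (the model, tokenizer and input below are only illustrative):

import torch
from transformers import BertForSequenceClassification, BertTokenizer

model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

attentionermanger = BertAttentionerManager(model)

inputs = tokenizer("a toy sentence", return_tensors="pt")
loss = model(**inputs, labels=torch.tensor([0])).loss
loss.backward()

# per-layer saliency, analogous to attention_attr.py
saliency = attentionermanger.grad(use_abs=True)
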
leanwang326 commented 8 months ago

This is just a way to obtain the gradient with respect to the attention matrix. You could also modify the BERT source directly; I replaced BERT's attention-computation function simply to avoid editing the library code. Alternatively, you could use torch hooks: registering a hook on attention_probs would work as well, and is probably quite a bit more convenient.
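
A minimal sketch of that hook-based route on an unpatched BERT model: run with output_attentions=True so the returned attention tensors stay in the autograd graph, register a backward hook on each, and read off |A * dL/dA| after the backward pass (the model, input and variable names here are only illustrative).

import torch
from transformers import BertForSequenceClassification, BertTokenizer

model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model.eval()

inputs = tokenizer("a toy sentence", return_tensors="pt")
outputs = model(**inputs, labels=torch.tensor([0]), output_attentions=True)

grads = {}
for layer_idx, att in enumerate(outputs.attentions):
    # each att has shape (batch, num_heads, seq_len, seq_len)
    att.register_hook(lambda g, i=layer_idx: grads.__setitem__(i, g))

outputs.loss.backward()

# elementwise |A * dL/dA| per layer, which should correspond to the saliency
# computed via the adapter-weight trick above
saliency = {i: (outputs.attentions[i].detach() * grads[i]).abs() for i in grads}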

lzy37ld commented 5 months ago

@leanwang326 Very inspiring paper! I was wondering why the earlier parts (before Sec. 2.3) use the saliency score, while Sec. 2.3 switches to the attention score.

edit: the question came to mind after reading this (https://aclanthology.org/2020.blackboxnlp-1.14.pdf)

leanwang326 commented 5 months ago

We actually used both, in effect, to corroborate the finding from two angles.

lzy37ld commented 5 months ago

@leanwang326 Thanks for the reply! So can they be considered interchangeable (i.e., is the saliency score in some sense equivalent to the attention score)? Or was the figure simply drawn with saliency from the start, and since it looked good you never tried attention?

leanwang326 commented 5 months ago

Attention sometimes seems a bit less reliable, which is why I looked at both. As for the figure, it did come out well with saliency, so I didn't consider redrawing it with attention.