fastnlp / TENER

Codes for "TENER: Adapting Transformer Encoder for Named Entity Recognition"
370 stars 55 forks source link

请教如何理解函数_transpose_shift #34

Open yangshuodelove opened 2 years ago

yangshuodelove commented 2 years ago

请教relative_transformer.py中_transpose_shift函数中的几个问题: (1)它实现了矩阵的什么变换(如移动、旋转等)? (2)怎么理解它是如何实现这种变换的呢? (3)倒数第3行中的indice为什么只选取奇数行呢? (4)_transpose_shift函数与_shift有什么区别,又有什么联系? 十分感谢


def _transpose_shift(self, E): 
       #E=[B,N,L,2*L]=[bsz, head, max_len, 2max_len] 如[2, 4, 68, 136];
        bsz, n_head, max_len, _ = E.size()
        zero_pad = E.new_zeros(bsz, n_head, max_len, 1)
        E = torch.cat([E, zero_pad], dim=-1).view(bsz, n_head, -1, max_len) # [B,N,2L,L] 
        indice = (torch.arange(max_len) * 2 + 1).to(E.device) # 选取是奇数行:[1,3,5...135]
        E = E.index_select(index=indice, dim=-2).transpose(-1, -2)  
        return E
yhcc commented 2 years ago

在对应函数的开头,我们都放置了一个例子,你可以对照着看看。 https://github.com/fastnlp/TENER/blob/d2614d509dffb9b30636e3523a2f8f0dc4876708/modules/relative_transformer.py#L164 如果想弄清楚每一步在做什么的话,建议可以初始化一个例子的矩阵,并把bsz和head的维度都设置为1,然后打印每一步的输出,大概就能知道每一步的效果是怎样了。