gaohongkui / GlobalPointer_pytorch

A PyTorch implementation of GlobalPointer for handling nested and non-nested (flat) NER in a unified way

Dimension issue with split #16

Closed Alwin4Zhang closed 2 years ago

Alwin4Zhang commented 2 years ago

Su Jianlin's original TF version:

def __init__(
        self,
        heads,
        head_size,
        RoPE=True,
        use_bias=True,
        kernel_initializer='glorot_uniform',
        **kwargs
    ):
        super(GlobalPointer, self).__init__(**kwargs)
        self.heads = heads
        self.head_size = head_size
        ...

def call(self, inputs, mask=None):
        # input transform
        inputs = self.dense(inputs)
        inputs = tf.split(inputs, self.heads, axis=-1)
        ...
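For reference, the integer second argument of tf.split is the number of splits, so the line above returns heads tensors, each of size head_size * 2 on the last axis. A minimal sketch with illustrative shapes (the values are assumptions, not taken from the repo):

import tensorflow as tf

heads, head_size = 4, 8  # illustrative values, not from the repo
x = tf.random.normal((2, 3, heads * head_size * 2))  # (batch, seq_len, heads * head_size * 2)

# integer num_or_size_splits -> NUMBER of equal pieces along the axis
pieces = tf.split(x, heads, axis=-1)
print(len(pieces), pieces[0].shape)  # 4 (2, 3, 16)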

Your version:

def __init__(self, encoder, ent_type_size, inner_dim, RoPE=True):
    super().__init__()
    self.encoder = encoder
    self.ent_type_size = ent_type_size # number of entity types
    self.inner_dim = inner_dim # head_size??? i.e. the size of each head???
    self.hidden_size = encoder.config.hidden_size
    self.dense = nn.Linear(self.hidden_size, self.ent_type_size * self.inner_dim * 2)
    ......

def forward(self, input_ids, attention_mask, token_type_ids):
    .......
    outputs = self.dense(last_hidden_state)
    outputs = torch.split(outputs, self.inner_dim * 2, dim=-1)

My understanding of Su Jianlin's version is that head_size is the size of each head and heads is the number of heads, i.e. the number of entity types; the split that follows then unpacks along the entity-type dimension. In your version, torch.split unpacks in chunks of head_size * 2. Is my understanding off, or is there a mistake here? Could you please clarify? Thanks!
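For comparison, a minimal sketch of the torch.split behaviour in question (shapes are illustrative and not taken from either repo): the integer second argument of torch.split is the size of each chunk along dim, not the number of chunks, so splitting by inner_dim * 2 yields ent_type_size chunks, the same list of per-type tensors that tf.split(inputs, heads, axis=-1) produces.

import torch

ent_type_size, inner_dim = 4, 8  # correspond to heads, head_size in the TF version
outputs = torch.randn(2, 3, ent_type_size * inner_dim * 2)

# integer split_size_or_sections -> SIZE of each chunk along dim
chunks = torch.split(outputs, inner_dim * 2, dim=-1)
print(len(chunks), chunks[0].shape)  # 4 torch.Size([2, 3, 16])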