131250208 / TPlinker-joint-extraction


torch.repeat and for loop in forward step cost a lot of time #12

Closed JaheimLee closed 3 years ago

JaheimLee commented 3 years ago

Do you have any suggestions for speeding up the training step or the inference step?

131250208 commented 3 years ago

You could replace the for loop in the handshaking kernel with matrix operations. Try this code:

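import torch
import torch.nn as nn

# LayerNorm below is assumed to be the repository's conditional layer normalization;
# torch.nn.LayerNorm has no `conditional` argument.

class HandshakingKernel(nn.Module):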
    def __init__(self, hidden_size, shaking_type, only_look_after=True):
        super().__init__()
        self.shaking_type = shaking_type
        self.only_look_after = only_look_after

        if "cat" in shaking_type:
            self.cat_fc = nn.Linear(hidden_size * 2, hidden_size)
        if "cln" in shaking_type:
            self.tp_cln = LayerNorm(hidden_size, hidden_size, conditional=True)
        if "lstm" in shaking_type:
            assert only_look_after is True
            self.lstm4span = nn.LSTM(hidden_size,
                                     hidden_size,
                                     num_layers=1,
                                     bidirectional=False,
                                     batch_first=True)

    def forward(self, seq_hiddens):
        '''
        seq_hiddens: (batch_size, seq_len, hidden_size)
        return:
            if only_look_after:
                shaking_hiddens: (batch_size, (1 + seq_len) * seq_len / 2, hidden_size); e.g. (32, 5+4+3+2+1, 5)
            else:
                shaking_hiddens: (batch_size, seq_len, seq_len, hidden_size); the matrix is not flattened in this branch
        '''
        seq_len = seq_hiddens.size()[1]

        guide = seq_hiddens[:, :, None, :].repeat(1, 1, seq_len, 1)
        visible = guide.permute(0, 2, 1, 3)

        shaking_pre = None

        # pre_num = 0

        def add_presentation(all_prst, prst):
            if all_prst is None:
                all_prst = prst
            else:
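                # out-of-place add: an in-place += would overwrite a tensor that
                # autograd may still need for backward (e.g. the ReLU output)
                all_prst = all_prst + prst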
            return all_prst

        if self.only_look_after:
            if "lstm" in self.shaking_type:
                batch_size, _, matrix_size, vis_hidden_size = visible.size()
                # mask lower triangle
                upper_visible = visible.permute(0, 3, 1, 2).triu().permute(0, 2, 3, 1).contiguous()

                # visible4lstm: (batch_size * matrix_size, matrix_size, hidden_size)
                visible4lstm = upper_visible.view(-1, matrix_size, vis_hidden_size)
                span_pre, _ = self.lstm4span(visible4lstm)
                span_pre = span_pre.view(batch_size, matrix_size, matrix_size, vis_hidden_size)

                # drop lower triangle and convert matrix to sequence
                # span_pre: (batch_size, shaking_seq_len, hidden_size)
                span_pre = MyMatrix.upper_reg2seq(span_pre)
                shaking_pre = add_presentation(shaking_pre, span_pre)

            # guide, visible: (batch_size, shaking_seq_len, hidden_size)
            guide = MyMatrix.upper_reg2seq(guide)
            visible = MyMatrix.upper_reg2seq(visible)

        if "cat" in self.shaking_type:
            tp_cat_pre = torch.cat([guide, visible], dim=-1)
            tp_cat_pre = torch.relu(self.cat_fc(tp_cat_pre))
            shaking_pre = add_presentation(shaking_pre, tp_cat_pre)

        if "cln" in self.shaking_type:
            tp_cln_pre = self.tp_cln(visible, guide)
            shaking_pre = add_presentation(shaking_pre, tp_cln_pre)

        return shaking_pre

class MyMatrix:
    @staticmethod
    def upper_reg2seq(ori_tensor):
        '''
        drop lower region and flat upper region to sequence
        :param ori_tensor: (batch_size, matrix_size, matrix_size, hidden_size)
        :return: (batch_size, matrix_size + ... + 1, hidden_size)
        '''
        # tensor: (batch_size, hidden_size, matrix_size, matrix_size)
        tensor = ori_tensor.permute(0, 3, 1, 2).contiguous()
        upper_ones = torch.ones([tensor.size()[-1], tensor.size()[-1]]).long().triu().to(ori_tensor.device)
        # flat positions of the upper triangle (diagonal included), in row-major order
        upper_diag_ids = torch.nonzero(upper_ones.view(-1), as_tuple=False).view(-1)
        # flat_tensor: (batch_size, matrix_size * matrix_size, hidden_size)
        flat_tensor = tensor.view(tensor.size()[0], tensor.size()[1], -1).permute(0, 2, 1)
        tensor_upper = torch.index_select(flat_tensor, dim=1, index=upper_diag_ids)
        return tensor_upper
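
The `repeat` call in `forward` builds two full (batch_size, seq_len, seq_len, hidden_size) tensors before the lower triangle is dropped. A minimal sketch of one possible further step, assuming only the "cat" and "cln" paths are used (the "lstm" path still needs the full matrix): gather the upper-triangle pairs directly with `torch.triu_indices`. The name `handshaking_pairs` is hypothetical and not part of the repository, and whether this is faster in practice would need to be measured.

```python
import torch


def handshaking_pairs(seq_hiddens):
    """Hypothetical helper (not from the repository).

    seq_hiddens: (batch_size, seq_len, hidden_size)
    returns guide, visible: each (batch_size, seq_len * (seq_len + 1) / 2, hidden_size),
    where pair k holds token i (guide) and token j >= i (visible), in row-major order.
    """
    seq_len = seq_hiddens.size(1)
    # Row and column indices of the upper triangle, diagonal included.
    rows, cols = torch.triu_indices(seq_len, seq_len, device=seq_hiddens.device)
    guide = seq_hiddens[:, rows]    # token i of every (i, j) pair
    visible = seq_hiddens[:, cols]  # token j of every (i, j) pair
    return guide, visible


seq_hiddens = torch.randn(2, 5, 8)
guide, visible = handshaking_pairs(seq_hiddens)
print(guide.shape, visible.shape)  # both torch.Size([2, 15, 8])
```

The two outputs are meant to match what `MyMatrix.upper_reg2seq(guide)` and `MyMatrix.upper_reg2seq(visible)` produce, since both enumerate the upper triangle row by row.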

I also recommend using TPLinkerPlus. It is faster than the original TPLinker.

JaheimLee commented 3 years ago

Thanks! Yeah, I have been using tplinker_plus. Your code did speed up the training steps, but it seems to slow down the validation step in my case. Have you seen anything similar?

131250208 commented 3 years ago

That is weird. How much did it slow down? I have not run this comparison myself. It may have other causes, e.g. the two experiments were run in different environments, or the CPUs and GPUs were shared with other processes during the validation steps.
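
One way to narrow this down is to time both kernel versions in the same process, on the same device, with the same inputs. The sketch below is an assumption about how such a comparison could be set up, not something that was run in this thread. `time_forward` is a hypothetical helper, and `nn.Linear` only stands in for the module under test. CUDA kernels are launched asynchronously, so the clock is read only after `torch.cuda.synchronize()`.

```python
import time

import torch
import torch.nn as nn


def time_forward(module, inputs, n_warmup=3, n_runs=10):
    """Hypothetical helper: average seconds per forward pass in eval mode."""
    was_training = module.training
    module.eval()
    with torch.no_grad():
        for _ in range(n_warmup):  # warm-up runs are not timed
            module(inputs)
        if inputs.is_cuda:
            torch.cuda.synchronize()  # wait for queued GPU work before starting the clock
        start = time.perf_counter()
        for _ in range(n_runs):
            module(inputs)
        if inputs.is_cuda:
            torch.cuda.synchronize()  # wait for the timed GPU work to finish
        elapsed = time.perf_counter() - start
    module.train(was_training)  # restore the original train/eval mode
    return elapsed / n_runs


# Stand-in module; swap in each HandshakingKernel variant to compare them.
layer = nn.Linear(16, 16)
batch = torch.randn(4, 10, 16)
seconds = time_forward(layer, batch)
print(f"{seconds * 1e3:.3f} ms per forward pass")
```

Running this once per kernel version, with sequence lengths taken from the validation set, would show whether the slowdown comes from the kernel itself or from the surrounding environment.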