Closed: JaheimLee closed this issue 3 years ago.
You could replace the for loop in the handshaking kernel with matrix operations. Try this code:
class HandshakingKernel(nn.Module):
    """Build a representation for every (start, end) token pair of a sequence.

    Each branch enabled by ``shaking_type`` ("cat", "cln", "lstm") produces a
    pair representation; the enabled branches are summed. This is a vectorised
    replacement for a per-start-token loop: pair features are gathered with
    index tensors instead of being built row by row.
    """

    def __init__(self, hidden_size, shaking_type, only_look_after=True):
        """
        :param hidden_size: feature size of the input hidden states and of the output.
        :param shaking_type: string; each of the substrings "cat", "cln" and
            "lstm" it contains enables the corresponding branch.
        :param only_look_after: if True, only pairs (i, j) with j >= i are built;
            otherwise all seq_len * seq_len pairs are built.
        :raises ValueError: if "lstm" is requested with only_look_after not True.
        """
        super().__init__()
        self.shaking_type = shaking_type
        self.only_look_after = only_look_after
        if "cat" in shaking_type:
            self.cat_fc = nn.Linear(hidden_size * 2, hidden_size)
        if "cln" in shaking_type:
            # NOTE(review): `LayerNorm` is not defined in this snippet; given the
            # `conditional=True` argument it is presumably a project-specific
            # conditional layer norm rather than torch's nn.LayerNorm.
            self.tp_cln = LayerNorm(hidden_size, hidden_size, conditional=True)
        if "lstm" in shaking_type:
            # The span LSTM reads forward from each start token, so it is only
            # defined for upper-triangle (j >= i) pairs. Raise instead of
            # `assert` so the check survives `python -O`.
            if only_look_after is not True:
                raise ValueError("the 'lstm' shaking type requires only_look_after=True")
            self.lstm4span = nn.LSTM(hidden_size,
                                     hidden_size,
                                     num_layers=1,
                                     bidirectional=False,
                                     batch_first=True)

    def forward(self, seq_hiddens):
        '''
        seq_hiddens: (batch_size, seq_len, hidden_size)
        return:
            if only look after:
                shaking_hiddens: (batch_size, (1 + seq_len) * seq_len / 2, hidden_size); e.g. (32, 5+4+3+2+1, 5)
                pairs ordered (0,0), (0,1), ..., (0,n-1), (1,1), ..., (n-1,n-1)
            else:
                shaking_hiddens: (batch_size, seq_len * seq_len, hidden_size), pairs in row-major order
            None if shaking_type enables none of "cat", "cln", "lstm".
        '''
        seq_len = seq_hiddens.size(1)
        device = seq_hiddens.device
        if self.only_look_after:
            # Row-major (i, j) index pairs with j >= i.
            guide_ids, visible_ids = torch.triu_indices(seq_len, seq_len, device=device)
        else:
            positions = torch.arange(seq_len, device=device)
            guide_ids = positions.repeat_interleave(seq_len)  # i: 0,0,..,1,1,..
            visible_ids = positions.repeat(seq_len)           # j: 0,1,..,0,1,..
        # guide, visible: (batch_size, shaking_seq_len, hidden_size)
        # Gathering directly avoids materialising (batch, seq_len, seq_len, hidden) copies.
        guide = seq_hiddens[:, guide_ids]
        visible = seq_hiddens[:, visible_ids]

        shaking_pre = None

        def add_presentation(all_prst, prst):
            # Out-of-place add: an in-place `+=` would overwrite a tensor autograd
            # saved for backward (e.g. the relu output of the "cat" branch).
            return prst if all_prst is None else all_prst + prst

        if self.only_look_after and "lstm" in self.shaking_type:
            span_pre = self._span_lstm(seq_hiddens, guide_ids, visible_ids)
            shaking_pre = add_presentation(shaking_pre, span_pre)
        if "cat" in self.shaking_type:
            tp_cat_pre = torch.cat([guide, visible], dim=-1)
            tp_cat_pre = torch.relu(self.cat_fc(tp_cat_pre))
            shaking_pre = add_presentation(shaking_pre, tp_cat_pre)
        if "cln" in self.shaking_type:
            tp_cln_pre = self.tp_cln(visible, guide)
            shaking_pre = add_presentation(shaking_pre, tp_cln_pre)
        return shaking_pre

    def _span_lstm(self, seq_hiddens, guide_ids, visible_ids):
        """Run the span LSTM from every start token i over tokens i, i+1, ..., seq_len-1.

        Returns (batch_size, len(guide_ids), hidden_size); entry k is the LSTM
        output after reading tokens guide_ids[k] .. visible_ids[k].
        """
        batch_size, seq_len, hidden_size = seq_hiddens.size()
        steps = torch.arange(seq_len, device=seq_hiddens.device)
        # token_ids[i, t] = i + t: the token fed at step t of the row starting at i.
        # Positions past the end are clamped; their outputs are discarded below, and
        # a unidirectional LSTM's earlier outputs do not depend on later inputs.
        token_ids = (steps[:, None] + steps[None, :]).clamp(max=seq_len - 1)
        shifted = seq_hiddens[:, token_ids].reshape(batch_size * seq_len, seq_len, hidden_size)
        span_out, _ = self.lstm4span(shifted)
        span_out = span_out.reshape(batch_size, seq_len * seq_len, -1)
        # Pair (i, j) is the output of row i at step t = j - i.
        flat_ids = guide_ids * seq_len + (visible_ids - guide_ids)
        return span_out[:, flat_ids]
class MyMatrix:
    """Helpers for (batch_size, matrix_size, matrix_size, hidden_size) pair tensors."""

    @staticmethod
    def upper_reg2seq(ori_tensor):
        '''
        Drop the lower region (below the diagonal) and flatten the upper region to a sequence.

        :param ori_tensor: (batch_size, matrix_size, matrix_size, hidden_size)
        :return: (batch_size, matrix_size + ... + 1, hidden_size); the cells (i, j)
            with j >= i, in row-major order.
        '''
        matrix_size = ori_tensor.size(2)
        # Row-major flat positions i * matrix_size + j of the cells with j >= i.
        # triu_indices builds them directly on the target device; the previous
        # ones-matrix + torch.nonzero approach forces a host-device sync on CUDA.
        rows, cols = torch.triu_indices(matrix_size, matrix_size, device=ori_tensor.device)
        upper_ids = rows * matrix_size + cols
        # flat_tensor: (batch_size, matrix_size * matrix_size, hidden_size).
        # Merging dims 1 and 2 yields index i * matrix_size + j directly, so no
        # permute/contiguous copy of the whole tensor is needed.
        flat_tensor = ori_tensor.reshape(ori_tensor.size(0), -1, ori_tensor.size(-1))
        return torch.index_select(flat_tensor, dim=1, index=upper_ids)
I also recommend using TPLinkerPlus; it is faster than the original TPLinker.
Thanks! Yes, I have been using tplinker_plus. Your code did speed up the training steps, but it seems to slow down the validation step in my case. Have you seen anything similar?
That is weird. How much did it slow down? I have not run this comparison myself. It may be caused by other factors, e.g. the two experiments were run in different environments, or the CPUs and GPUs were shared with other processes during the validation steps.
Do you have any suggestions for speeding up the training or inference steps?