[Open] wj-Mcat opened this issue 4 years ago
I have found some example code:
import torch
import torch.nn as nn
import torch.nn.functional as F


class SlotAttention(nn.Module):
    def __init__(self, n_features=64):
        super(SlotAttention, self).__init__()
        self.attention = nn.Linear(n_features, n_features)

    def forward(self, x):
        """
        :param x: hidden states of LSTM (batch_size, seq_len, hidden_size)
        :return: slot attention vector of size (batch_size, seq_len, hidden_size)
        output = softmax(linear(x) @ x^T) @ x
        """
        weights = self.attention(x)  # (batch_size, seq_len, hidden_size) - temporary weights
        weights = torch.matmul(weights, torch.transpose(x, 1, 2))  # (batch_size, seq_len, seq_len) - attention matrix
        weights = F.softmax(weights, dim=2)
        output = torch.matmul(weights, x)  # (batch_size, seq_len, hidden_size) - one context vector per token
        return output


class IntentAttention(nn.Module):
    def __init__(self, n_features=64):
        super(IntentAttention, self).__init__()
        self.attention = nn.Linear(n_features, n_features)

    def forward(self, x):
        """
        :param x: hidden states of LSTM (batch_size, seq_len, hidden_size)
        :return: intent vector of size (batch_size, hidden_size)
        """
        weights = self.attention(x)  # (batch_size, seq_len, hidden_size) - temporary weights
        weights = torch.matmul(weights, torch.transpose(x, 1, 2))  # (batch_size, seq_len, seq_len) - attention matrix
        weights = F.softmax(weights, dim=2)
        output = torch.matmul(weights, x)  # (batch_size, seq_len, hidden_size)
        output = torch.sum(output, 1)  # (batch_size, hidden_size) - pooled over the sequence
        return output
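For reference, a quick shape check of the two modules above. The batch size and sequence length here are made up; only the n_features=64 default comes from the snippet:

import torch

x = torch.randn(2, 10, 64)  # stand-in for BiLSTM hidden states: (batch_size=2, seq_len=10, hidden_size=64)
slot_att = SlotAttention(n_features=64)
intent_att = IntentAttention(n_features=64)

print(slot_att(x).shape)    # torch.Size([2, 10, 64]) - one context vector per token
print(intent_att(x).shape)  # torch.Size([2, 64])     - one vector per utterance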
But this computation logic is different from the cited paper. The attention mechanism in the cited paper is:

[screenshot of the paper's attention equations]

But your attention mechanism is much simpler: https://github.com/ZephyrChenzf/SF-ID-Network-For-NLU/blob/67f0bc7339d007d48f3c2d64ba41c8b0d668cea2/train.py#L113
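To make the comparison concrete, this is what the pasted modules compute, written out for one example with hidden states $X \in \mathbb{R}^{T \times d}$ (my own reading of the snippet above, not a formula taken from the paper or from train.py):

$$
A = \operatorname{softmax}\big((XW^\top + b)\,X^\top\big) \in \mathbb{R}^{T \times T},\qquad
\text{SlotAttention}(X) = AX,\qquad
\text{IntentAttention}(X) = \sum_{i=1}^{T} (AX)_i
$$

where $W$ and $b$ are the weight and bias of the nn.Linear layer and the softmax normalizes each row. So both modules build a full token-to-token attention matrix from a single linear layer and then take weighted sums of the hidden states.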
I can't find the SlotAttention and IntentAttention code in this repo. Can anyone help me?