prajjwal1 closed this issue 4 years ago
Hi! Thank you for your interest!
Yes, setting it as a Parameter is sufficient since entmax_bisect
already takes care of the gradient computation for you.
I'm not really sure about your code snippet, so maybe this can help:
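import torch
from entmax import entmax_bisect

# AlphaChooser (not shown in this thread) returns one alpha per head.
class EntmaxAlpha(torch.nn.Module):
    def __init__(self, head_count, n_iter=25, dim=0):
        super(EntmaxAlpha, self).__init__()
        self.dim = dim
        self.n_iter = n_iter
        self.alpha_chooser = AlphaChooser(head_count)

    def forward(self, X):
        batch_size, head_count, query_len, key_len = X.size()
        # Flatten to one row of attention scores per (batch, head, query).
        X = X.view(-1, key_len)
        self.alpha = self.alpha_chooser()
        # Repeat each head's alpha so every flattened row gets its own value.
        expanded_alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        expanded_alpha = expanded_alpha.expand((batch_size, -1, query_len))
        expanded_alpha = expanded_alpha.contiguous().view(-1)
        p_star = entmax_bisect(X, expanded_alpha, n_iter=self.n_iter)
        # Restore the original (batch, head, query, key) layout.
        return p_star.view(batch_size, head_count, query_len, -1)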
We used this EntmaxAlpha just as you would use nn.Softmax in a regular multi-head attention implementation (particularly, we used OpenNMT).
The code snippet above should answer your question. You can compute the attention probabilities for the several heads in parallel (even if they have different alphas).
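To make the flattening in the snippet above concrete, here is a minimal dependency-free sketch of how one alpha per head becomes one alpha per row once the scores are reshaped to (-1, key_len). The sizes and alpha values are made up for illustration, and plain Python lists stand in for tensors; this is not the authors' code.

```python
# Hypothetical sizes and per-head alphas, chosen only for illustration.
batch_size, head_count, query_len = 2, 3, 4
alpha = [1.2, 1.5, 1.9]  # one alpha per head

# X.view(-1, key_len) orders rows by (batch, head, query), so the alphas must
# follow the same order: each head's alpha is repeated query_len times, and the
# whole per-head pattern is repeated once per batch element.
expanded_alpha = [
    alpha[h]
    for _ in range(batch_size)
    for h in range(head_count)
    for _ in range(query_len)
]


def head_of_row(row):
    """Head index that a row of the flattened score matrix belongs to."""
    return (row // query_len) % head_count


print(len(expanded_alpha))  # batch_size * head_count * query_len rows
print(expanded_alpha[:query_len + 1])  # head 0 repeated, then head 1 begins
```

Because every row carries its own alpha, a single call over the flattened rows covers all heads at once; no loop over heads and no averaging is needed.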
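Note that p.mean() is a constant since p always sums to 1 along the normalized dimension, so its mean is always one over the size of that dimension regardless of the scores or of alpha. Since it's a constant, calling backward() on it will get you zero gradients.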
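A quick numerical check of this point, as a minimal sketch: softmax is used here as a stand-in for entmax (an assumption for illustration; both produce outputs that sum to 1), and finite differences stand in for autograd's backward(). The input scores are arbitrary.

```python
import math


def softmax(z):
    """Numerically stable softmax; stands in for any mapping whose output sums to 1."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def mean_of_probs(z):
    """Mean of the probability vector produced from the scores z."""
    p = softmax(z)
    return sum(p) / len(p)


def finite_diff_grad(f, z, eps=1e-6):
    """Central finite-difference estimate of the gradient of f at z."""
    grad = []
    for i in range(len(z)):
        up = z[:i] + [z[i] + eps] + z[i + 1:]
        down = z[:i] + [z[i] - eps] + z[i + 1:]
        grad.append((f(up) - f(down)) / (2 * eps))
    return grad


scores = [0.3, -1.2, 2.0, 0.7]  # arbitrary illustrative scores
grads = finite_diff_grad(mean_of_probs, scores)
print(mean_of_probs(scores))  # always 1 / len(scores), whatever the scores are
print(grads)  # numerically zero in every coordinate
```

A loss that depends on how the probability mass is distributed, rather than on its total, is what produces non-zero gradients.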
The shape of each attention head is learned automatically via backpropagation! :)
Hi, firstly, thanks for releasing the code of your paper. I had some queries. Here's what I'm doing:
How do I learn alpha correctly? Will setting it as a parameter be sufficient? In the paper you mentioned that it cannot be solved simply by autograd differentiation.
Will p be representative of the attention scores from all the heads?
I'm only using the first element of alpha(). How do I include all 12 values of alpha() when getting the output scores from entmax? Do I run a loop over all the elements of alpha and then take the mean? That would increase computation time.
What does 'val' signify here?
How do you learn the shape of each attention head in your work?
Could you please answer these queries? Any suggestion would be helpful. Thanks!