deep-spin / entmax

The entmax mapping and its loss, a family of sparse softmax alternatives.
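As background for the thread below: α-entmax maps a score vector z to p_i = [(α - 1) z_i - τ]_+^{1/(α - 1)}, where [·]_+ = max(·, 0) and the threshold τ is chosen so that the p_i sum to 1. Setting α = 2 gives sparsemax, and α → 1 recovers softmax. The pure-Python function below is a hypothetical sketch (entmax_bisect_py is an illustrative name, not part of the entmax package) that finds τ by bisection. It only shows the mechanics and is not the library's implementation.

```python
def entmax_bisect_py(z, alpha=1.5, n_iter=60):
    """Hypothetical pure-Python alpha-entmax for one row of scores z (needs alpha > 1)."""
    if alpha <= 1:
        raise ValueError("alpha must be > 1")
    x = [(alpha - 1) * zi for zi in z]
    inv = 1.0 / (alpha - 1)
    # At tau_lo the largest entry alone has mass 1, so the total mass is >= 1.
    # At tau_hi every entry has mass <= 1/len(z), so the total mass is <= 1.
    tau_lo = max(x) - 1.0
    tau_hi = max(x) - (1.0 / len(z)) ** (alpha - 1)
    for _ in range(n_iter):
        tau = (tau_lo + tau_hi) / 2
        mass = sum(max(xi - tau, 0.0) ** inv for xi in x)
        if mass >= 1.0:
            tau_lo = tau  # threshold too low: too much mass, move it up
        else:
            tau_hi = tau  # threshold too high: too little mass, move it down
    p = [max(xi - tau_lo, 0.0) ** inv for xi in x]
    total = sum(p)
    # Renormalize to absorb the small error left after n_iter bisection steps.
    return [pi / total for pi in p]


scores = [1.0, 0.5, 0.0, -1.0]
print(entmax_bisect_py(scores, alpha=1.5))  # three nonzero entries
print(entmax_bisect_py(scores, alpha=2.0))  # sparsemax: two nonzero entries
```

Larger α yields sparser outputs: on these scores, α = 1.5 keeps three entries, while α = 2 keeps only the top two (0.75 and 0.25).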
MIT License

Replicating the behaviour from the paper #15

Closed: prajjwal1 closed this issue 4 years ago.

prajjwal1 commented 4 years ago

Hi, first of all, thanks for releasing the code for your paper. I have a few questions. Here's what I'm doing:

>>> import torch
>>> from entmax import entmax_bisect
>>> att_scores = torch.rand(128, 12, 36, 1024)
>>> alpha = AlphaChooser(12)
>>> p = entmax_bisect(att_scores, alpha()[0])
>>> val = p.mean().backward()
  1. How do I learn alpha correctly? Is setting it as a parameter sufficient? In the paper you mention that it cannot be computed simply by automatic differentiation.

  2. Suppose I compute p like this:

    >>> p = 0.
    >>> for i in range(att_scores.size(1)):
    ...     p += entmax_bisect(X=att_scores, alpha=alpha()[i], dim=1)
    ...
    >>> p /= num_attention_heads

     Will p be representative of the attention scores from all the heads?

  3. I'm only using the first element of alpha(). How do I include all 12 values of alpha() when computing the output scores with entmax? Should I loop over all the elements of alpha and then take the mean? That would increase computation time.

  4. What does 'val' signify here?

  5. How do you learn the shape of each attention head in your work?

Could you please answer these questions? Any suggestion would be helpful. Thanks!

goncalomcorreia commented 4 years ago

Hi! Thank you for your interest!

  1. Yes, setting it as a Parameter is sufficient: entmax_bisect has its own backward pass that also returns the gradient with respect to alpha, so autograd does not need to differentiate through the bisection iterations.

  2. I'm not really sure what your code snippet computes, so maybe this helps:

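import torch
from entmax import entmax_bisect

# AlphaChooser(head_count) is the same module used in your snippet:
# calling it returns one alpha > 1 per head, with shape (head_count,).


class EntmaxAlpha(torch.nn.Module):

    def __init__(self, head_count, n_iter=25):
        super(EntmaxAlpha, self).__init__()
        self.n_iter = n_iter
        self.alpha_chooser = AlphaChooser(head_count)

    def forward(self, X):
        # X: attention scores of shape (batch_size, head_count, query_len, key_len)
        batch_size, head_count, query_len, key_len = X.size()

        # Flatten to one row of scores per (batch, head, query) triple.
        X = X.reshape(-1, key_len)
        self.alpha = self.alpha_chooser()  # (head_count,)
        # Repeat each head's alpha for every batch element and query position,
        # in the same (batch, head, query) order as the rows of X.
        expanded_alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # (1, head_count, 1)
        expanded_alpha = expanded_alpha.expand((batch_size, -1, query_len))
        # entmax_bisect expects alpha to have size 1 along the normalized dim.
        expanded_alpha = expanded_alpha.contiguous().view(-1, 1)
        p_star = entmax_bisect(X, expanded_alpha, dim=-1, n_iter=self.n_iter)

        return p_star.view(batch_size, head_count, query_len, -1)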

We used this EntmaxAlpha exactly where you would use nn.Softmax in a regular multi-head attention implementation (specifically, OpenNMT's).

  3. The code snippet above should answer your question: you can compute the attention probabilities for all the heads in parallel, even if each head has a different alpha.

  4. Note that p.mean() is a constant: every row of p sums to 1, so the mean is always 1/key_len. Calling backward() on a constant gives zero gradients. Also, backward() itself returns None, so val is just None.

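  5. This is learned automatically via backpropagation: each head's alpha is a trainable parameter, so its value (and hence how sparse that head's attention is) is updated along with the rest of the model. :)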