SkBlaz / san

Attention-based feature ranking for propositional data.
GNU General Public License v3.0

Question about the implementation in SAN #9

Closed · APCunha closed this issue 1 year ago

APCunha commented 1 year ago

Hello, I have a question regarding the implementation of SAN, specifically the "instance-level aggregations", as you call them in the paper. I assume we get those values by calling the get_attention function, in which return_softmax is True. In the forward_attention function, though: https://github.com/SkBlaz/san/blob/4ab2397570c9131c9983eb53475760929a05a343/san/__init__.py#L67C1-L80C19 when return_softmax = True we get a separate matrix for each head, add them up via placeholder, divide by the number of heads (performing the average), and only then apply the softmax. Is this correct? If so, my question is: aren't you first taking the average and only then applying the softmax to that average, rather than the other way around, i.e. first applying the softmax to each head and then averaging, as in the paper:

[image: instance-level aggregation formula from the paper]
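To make the difference concrete, here is a minimal numerical check (a toy sketch with random head outputs, not the repository's code) showing that averaging softmaxes and softmaxing the average generally disagree:

    import torch

    torch.manual_seed(0)
    softmax = torch.nn.Softmax(dim=1)
    heads = [torch.randn(4, 5) for _ in range(3)]  # toy per-head scores

    # Ordering in the current implementation: average the heads, then softmax.
    avg_then_softmax = softmax(torch.stack(heads).mean(dim=0))

    # Ordering in the paper: softmax each head, then average.
    softmax_then_avg = torch.stack([softmax(h) for h in heads]).mean(dim=0)

    print(torch.allclose(avg_then_softmax, softmax_then_avg))  # False in general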

Shouldn't it be something like the following? (This way we get the instance attention matrix as well as the output, even with return_softmax = False.)

def forward_attention(self, input_space, return_softmax=False):
    placeholder = torch.zeros(input_space.shape).to(self.device)
    placeholder2 = torch.zeros(input_space.shape).to(self.device)
    for k in range(len(self.multi_head)):
        # Per-head softmax attention, accumulated in placeholder2 so the
        # instance-level attention is an average of softmaxes, as in the paper.
        attention_matrix = self.softmax3(self.multi_head[k](input_space))
        if return_softmax:
            # Original behaviour: average the raw head outputs, softmax at the end.
            attended_matrix = self.multi_head[k](input_space)
        else:
            attended_matrix = attention_matrix * input_space
        placeholder = torch.add(placeholder, attended_matrix)
        placeholder2 = torch.add(placeholder2, attention_matrix)
    placeholder /= len(self.multi_head)
    placeholder2 /= len(self.multi_head)
    out = placeholder
    if return_softmax:
        out = self.softmax(out)
    return out, placeholder2

Like factoring out the X in this expression:

[image: the same formula with X factored out of the averaged softmaxes]

Am I overlooking something, or is this a stupid question? Sorry if I put it in a confusing manner.
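To spell the two orderings out (my notation, paraphrased from the discussion above; $\Omega_k(X)$ stands for the output of the $k$-th attention head and $|H|$ for the number of heads):

$$\frac{1}{|H|}\sum_{k=1}^{|H|}\operatorname{softmax}\big(\Omega_k(X)\big)\;\neq\;\operatorname{softmax}\left(\frac{1}{|H|}\sum_{k=1}^{|H|}\Omega_k(X)\right)\quad\text{in general,}$$

where the left-hand side is the paper's ordering and the right-hand side is what forward_attention computes when return_softmax = True. The "factoring out" step is just distributivity of the element-wise product:

$$\frac{1}{|H|}\sum_{k=1}^{|H|}\Big(\operatorname{softmax}\big(\Omega_k(X)\big)\odot X\Big)\;=\;\left(\frac{1}{|H|}\sum_{k=1}^{|H|}\operatorname{softmax}\big(\Omega_k(X)\big)\right)\odot X.$$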

SkBlaz commented 1 year ago

Hi @APCunha, thanks for the question, good point. We used return_softmax=True as the default for get_attention since this also works and is a bit more efficient; it is indeed not the same as the paper's formula when calling get_instance_attention directly. I suggest you open a PR with your change, as it nicely covers both use cases (and removes the need for manual calls to the forward_attention method; such calls are currently needed to replicate the quoted formula, as you perhaps observed).

APCunha commented 1 year ago

Done! However, do you think it would be better to do something like this instead:

def forward_attention(self, input_space):
    placeholder = torch.zeros(input_space.shape).to(self.device)
    placeholder2 = torch.zeros(input_space.shape).to(self.device)  # define placeholder2
    for k in range(len(self.multi_head)):
        attention_matrix = self.softmax3(self.multi_head[k](input_space))
        attended_matrix = attention_matrix * input_space
        placeholder = torch.add(placeholder, attended_matrix)
        placeholder2 = torch.add(placeholder2, attention_matrix)
    placeholder /= len(self.multi_head)
    placeholder2 /= len(self.multi_head)
    out = placeholder
    return out, placeholder2  # return placeholder2 in addition to out
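For anyone who wants to try this variant in isolation, here is a self-contained sketch (TinySAN is a hypothetical harness, not the real SAN class; it only defines the attributes the snippet above relies on: multi_head, softmax3, and device):

    import torch

    class TinySAN(torch.nn.Module):
        # Hypothetical minimal harness: just enough state for forward_attention.
        def __init__(self, num_features, num_heads=2, device="cpu"):
            super().__init__()
            self.device = device
            self.multi_head = torch.nn.ModuleList(
                [torch.nn.Linear(num_features, num_features) for _ in range(num_heads)]
            )
            self.softmax3 = torch.nn.Softmax(dim=1)

        def forward_attention(self, input_space):
            placeholder = torch.zeros(input_space.shape).to(self.device)
            placeholder2 = torch.zeros(input_space.shape).to(self.device)
            for k in range(len(self.multi_head)):
                attention_matrix = self.softmax3(self.multi_head[k](input_space))
                attended_matrix = attention_matrix * input_space
                placeholder = torch.add(placeholder, attended_matrix)
                placeholder2 = torch.add(placeholder2, attention_matrix)
            placeholder /= len(self.multi_head)
            placeholder2 /= len(self.multi_head)
            return placeholder, placeholder2

    x = torch.randn(8, 10)  # 8 instances, 10 features
    out, instance_attention = TinySAN(num_features=10).forward_attention(x)
    print(out.shape, instance_attention.shape)  # torch.Size([8, 10]) twice
    print(instance_attention.sum(dim=1))        # each row sums to 1 (average of softmaxes)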
SkBlaz commented 1 year ago

Merged the first one for now; I don't see an apparent need for this, but feel free to open another PR and we can test it out.