lucidrains / performer-pytorch

An implementation of Performer, a linear attention-based transformer, in Pytorch
MIT License

A small question regarding `softmax_kernel` #36

Closed tianylin98 closed 3 years ago

tianylin98 commented 3 years ago

First things first, great repo.

I'm trying to understand the renormalization in `softmax_kernel`, though:


if is_query:
    data_dash = ratio * (
        torch.exp(data_dash - diag_data -
                  torch.max(data_dash, dim=-1, keepdim=True).values) + eps)
else:
    data_dash = ratio * (
        torch.exp(data_dash - diag_data - torch.max(data_dash)) + eps)

In this segment of code, the argument `is_query` is used to switch between two different computations.

I reckon this part is meant to alleviate numerical problems, but I wonder why the computation for query features and key features should differ (specifically, why the max op is taken over different axes: per row for queries, globally for keys)?

I'd really appreciate it if you could shed some light on this question.

tianylin98 commented 3 years ago

I think I get it now. It's probably because the attention normalization is done w.r.t. the query axis: scaling all of a query's features by a constant cancels between the numerator and denominator of that query's attention row, so a per-query max is safe, whereas keys must all share a single (global) max so their relative magnitudes within each row's sum are preserved.
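A quick numerical sketch of why that distinction is harmless (hypothetical shapes; `q_exp`/`k_exp` stand in for the `data_dash - diag_data` exponents, and `eps` is omitted for clarity):

    import torch

    m, n, r = 4, 5, 8                      # queries, keys, random features
    q_exp = torch.randn(m, r)              # stand-ins for data_dash - diag_data
    k_exp = torch.randn(n, r)

    phi_q = torch.exp(q_exp)               # unstabilized feature maps
    phi_k = torch.exp(k_exp)

    # stabilized versions, mirroring the two is_query branches above
    phi_q_s = torch.exp(q_exp - q_exp.max(dim=-1, keepdim=True).values)  # per-row max
    phi_k_s = torch.exp(k_exp - k_exp.max())                             # global max

    def attn(pq, pk):
        scores = pq @ pk.T                                 # linear-attention scores
        return scores / scores.sum(dim=-1, keepdim=True)   # normalize over keys

    print(torch.allclose(attn(phi_q, phi_k), attn(phi_q_s, phi_k_s), atol=1e-6))
    # True: exp(-row_max_i) scales row i's numerator and denominator equally,
    # and the single key constant exp(-global_max) cancels the same way.
    # A per-row max on the keys would NOT cancel, since it would rescale the
    # keys relative to one another inside each row's normalizing sum.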