First things first, great repo.

I'm trying to understand the renormalization in `softmax_kernel`, though:
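```python
if is_query:
    # queries: subtract each row's own max over the last (feature) dimension
    data_dash = ratio * (
        torch.exp(data_dash - diag_data -
                  torch.max(data_dash, dim=-1, keepdim=True).values) + eps)
else:
    # keys: subtract a single max taken over the entire tensor
    data_dash = ratio * (
        torch.exp(data_dash - diag_data - torch.max(data_dash)) + eps)
```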
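In this snippet, the `is_query` argument selects between two computations.
I reckon this part is there to avoid numerical overflow in `torch.exp` by subtracting a max before exponentiating. What I don't get is why the query and key features are handled differently: for queries the max is taken per row (`dim=-1, keepdim=True`), while for keys `torch.max(data_dash)` is a single max over the whole tensor. Why should the two differ?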
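My current guess, which I'd like confirmed: it depends on how the features are used afterwards. Assuming they feed normalized linear attention, out_i = Σ_j (φ(q_i)·φ(k_j)) v_j / Σ_j (φ(q_i)·φ(k_j)), a factor that is constant within one query row appears in both numerator and denominator and cancels, and so does a factor shared by all keys; a different factor per key would reweight the keys and not cancel. Below is a minimal sketch to check that numerically. It is plain Python rather than torch, the helpers `logits`, `to_features` and `linear_attention` are hypothetical names of my own, and it drops `ratio` (a constant that cancels anyway) and `eps`.

```python
import math
import random

# Hypothetical helpers for illustration only; plain Python instead of torch.


def logits(x, proj):
    """Unstabilized feature logits w_r . x - |x|^2 / 2 for each random row w_r."""
    half_sq_norm = sum(v * v for v in x) / 2.0
    return [sum(w_i * x_i for w_i, x_i in zip(w, x)) - half_sq_norm for w in proj]


def to_features(logit_rows, mode):
    """exp(logit - shift), where the shift is chosen by `mode`.

    "row":    each row's own max (like the is_query branch, dim=-1)
    "global": one max over every row (like the key branch, torch.max(data_dash))
    "none":   no shift, used as the reference
    """
    global_max = max(max(row) for row in logit_rows)
    features = []
    for row in logit_rows:
        if mode == "row":
            shift = max(row)
        elif mode == "global":
            shift = global_max
        else:
            shift = 0.0
        features.append([math.exp(v - shift) for v in row])
    return features


def linear_attention(q_feats, k_feats, values):
    """out_i = sum_j (phi(q_i) . phi(k_j)) v_j / sum_j (phi(q_i) . phi(k_j))."""
    outputs = []
    for q in q_feats:
        weights = [sum(a * b for a, b in zip(q, k)) for k in k_feats]
        total = sum(weights)
        outputs.append(
            [sum(w * v[c] for w, v in zip(weights, values)) / total
             for c in range(len(values[0]))]
        )
    return outputs


def max_abs_diff(a, b):
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


random.seed(0)
d, m, n = 4, 16, 5  # head dim, number of random features, sequence length
proj = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(m)]
Q = [[random.gauss(0.0, 0.5) for _ in range(d)] for _ in range(n)]
K = [[random.gauss(0.0, 0.5) for _ in range(d)] for _ in range(n)]
V = [[random.gauss(0.0, 1.0) for _ in range(2)] for _ in range(n)]

q_logits = [logits(q, proj) for q in Q]
k_logits = [logits(k, proj) for k in K]

reference = linear_attention(to_features(q_logits, "none"), to_features(k_logits, "none"), V)
# Same split as the snippet: per-row max for queries, one global max for keys.
as_in_code = linear_attention(to_features(q_logits, "row"), to_features(k_logits, "global"), V)
# Counterfactual: per-row max for the keys as well.
row_max_keys = linear_attention(to_features(q_logits, "row"), to_features(k_logits, "row"), V)

err_as_in_code = max_abs_diff(reference, as_in_code)
err_row_max_keys = max_abs_diff(reference, row_max_keys)
print(err_as_in_code)    # tiny: only floating-point rounding
print(err_row_max_keys)  # nonzero: each key gets its own scale factor
```

With the `+ eps` term from the snippet, the cancellation is no longer exact, because `eps` is added after the max is subtracted, so this only shows the idealized case.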
I'd really appreciate it if you could shed some light on this.