GanjinZero opened this issue 2 years ago
I guess the reason comes from `k = create_kernel(k, is_query = False)` in `FastAttention.forward`: the `softmax_kernel` operation contains the line `data_dash = ratio * (torch.exp(data_dash - diag_data - torch.amax(data_dash, dim=(-1, -2), keepdim=True)) + eps)`. The `torch.amax(data_dash, dim=(-1, -2), keepdim=True)` term contains information from later time steps, and that information is passed to the hidden states of earlier time steps.
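Here is a standalone sketch of what I mean (the shapes and values below are made up for illustration, not taken from the library defaults): because the max is shared across the whole sequence, perturbing only the last position changes the kernel features computed for position 0.

```python
import torch

torch.manual_seed(0)
ratio, eps = 1.0, 1e-4
data_dash = torch.randn(1, 1, 8, 16)        # (batch, heads, seq_len, nb_features); made-up shapes
diag_data = torch.randn(1, 1, 8, 1).abs()   # stand-in for the squared-norm term

def key_features(d):
    # mirrors the quoted line: the max is taken over the feature AND sequence dims
    return ratio * (torch.exp(d - diag_data - torch.amax(d, dim=(-1, -2), keepdim=True)) + eps)

phi_a = key_features(data_dash)

data_dash_b = data_dash.clone()
data_dash_b[..., -1, :] += 10.0             # perturb only the LAST sequence position
phi_b = key_features(data_dash_b)

# the features at position 0 change even though only the last position was modified
print((phi_a[..., 0, :] - phi_b[..., 0, :]).abs().max())
```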
@GanjinZero oh shoot, yeah, those maxes are for numerical stability, but I think they should be detached: https://github.com/lucidrains/performer-pytorch/commit/fc8b78441b1e27eb5d9b01fc738a8772cee07127 Can you let me know if this resolves the issue on your end?
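For reference, a hedged sketch of the pattern being described (not the exact diff in that commit): the max is treated as a constant via `.detach()`, so it serves only numerical stability and no gradient flows from later positions into earlier ones through it.

```python
import torch

def stabilized_key_features(data_dash, diag_data, ratio, eps):
    # the max over the (seq, feature) block is detached, so it acts purely as a
    # numerical-stability constant and carries no gradient from future positions
    data_max = torch.amax(data_dash, dim=(-1, -2), keepdim=True).detach()
    return ratio * (torch.exp(data_dash - diag_data - data_max) + eps)
```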
I want to apply an autoregressive Performer for decoding.
The output is `tensor(0.0003, device='cuda:0', grad_fn=)`.
If I turn off `use_scalenorm`, the output is `tensor(0.0085, device='cuda:0', grad_fn=)`.
This shows that the autoregressive Performer produces inconsistent outputs.
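For context, here is a minimal sketch of the kind of consistency check that could produce numbers like the ones above (the `PerformerLM` arguments, sizes, and split point are illustrative assumptions, not the original test): with a truly causal model, changing only future tokens should leave the logits at earlier positions unchanged.

```python
import torch
from performer_pytorch import PerformerLM

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# hypothetical model configuration; causal = True is the relevant bit
model = PerformerLM(
    num_tokens = 256, max_seq_len = 128, dim = 64,
    depth = 2, heads = 4, causal = True, use_scalenorm = True,
).to(device).eval()

x = torch.randint(0, 256, (1, 128), device = device)
y = x.clone()
y[:, 64:] = torch.randint(0, 256, (1, 64), device = device)  # change only the future tokens

with torch.no_grad():
    out_x = model(x)
    out_y = model(y)

# for a truly causal model, the logits before position 64 should not change at all
print((out_x[:, :64] - out_y[:, :64]).abs().max())
```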