lucidrains / performer-pytorch

An implementation of Performer, a linear attention-based transformer, in Pytorch
MIT License

torch.max(data_dash) bug #76

Closed · martinpflaum closed this issue 2 years ago

martinpflaum commented 2 years ago

Hello, I really like your implementation, but I think there is a mistake in line 109 of performer_pytorch. There, torch.max(data_dash) returns a single value, meaning the maximum is also taken across batches and across attention heads. Whereas in https://github.com/google-research/google-research/blob/master/performer/fast_attention/jax/fast_attention.py, line 108, this is not the case: last_dims_t + attention_dims_t is a tuple of axes (both last_dims_t and attention_dims_t are tuples), so the reduction is restricted to those axes only:

```python
data_dash = ratio * (jnp.exp(data_dash - diag_data - jnp.max(data_dash, axis=last_dims_t + attention_dims_t, keepdims=True)) + eps)
```

I didn't run the code above, but I think it is much more likely that they deliberately did not take the maximum across batches, and also did not take it across multiple attention heads.
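To make the difference concrete, here is a minimal PyTorch sketch of the two reductions (the tensor shape is an illustrative assumption, not taken from the library):

```python
import torch

# assumed shape convention: (batch, heads, seq_len, num_features)
data_dash = torch.randn(2, 8, 1024, 256)

# what line 109 does: one scalar max over ALL dimensions,
# i.e. shared across batches and attention heads
global_max = torch.max(data_dash)  # shape: () -- a single value

# what the JAX reference does: max over the feature and sequence
# axes only, kept broadcastable, so each (batch, head) slice is
# stabilized independently
per_head_max = torch.amax(data_dash, dim=(-1, -2), keepdim=True)  # shape: (2, 8, 1, 1)
```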

lucidrains commented 2 years ago

@martinpflaum Hi Martin, thank you for catching this error (though it should be harmless, as the max is only used for numerical stability). Do you want to check version 1.1.1 and see if it fixes the problem?
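For reference, a hedged sketch of how the quoted JAX reduction might translate to PyTorch; the helper name stabilized_exp is hypothetical, and this is not necessarily the exact 1.1.1 patch:

```python
import torch

def stabilized_exp(data_dash, diag_data, ratio, eps=1e-4, is_query=False):
    # hypothetical helper mirroring the quoted JAX line; the actual
    # library fix may differ in detail
    if is_query:
        # queries: stabilize each position independently (feature axis only)
        max_val = torch.amax(data_dash, dim=-1, keepdim=True)
    else:
        # keys: stabilize per (batch, head), over sequence and feature axes
        max_val = torch.amax(data_dash, dim=(-1, -2), keepdim=True)
    return ratio * (torch.exp(data_dash - diag_data - max_val) + eps)
```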

martinpflaum commented 2 years ago

Hi, yes looks good 👍