HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU training
BSD 2-Clause "Simplified" License

ALiBi Convolution #48

Closed: ClashLuke closed this issue 2 years ago

ClashLuke commented 2 years ago

Currently, we're implicitly adding a locality bias to our model by using convolutions and by giving small convolutions more gradient signal (via muParametrization). However, ALiBi demonstrated that adding a locality bias to attention can help it converge and extrapolate significantly better to sequences longer than those seen during training. ALiBi is the only position embedding that works at all scales, so we should take a closer look at adding it to our codebase.

We already have something akin to ALiBi with our QRNN (with #7 hopefully allowing us to run it more frequently), but that might not provide enough bias. We could add an ALiBi-inspired bias to our convolution weights to enforce locality further.

One approach could be as simple as multiplying the convolution kernel by a scaling factor that penalises long-context interactions, giving them less weight and less gradient signal during training. To implement this, we'd just have to expand our existing "parameter_variance" scalar to a tensor of shape [Kernel, 1, 1], which could contain any bias, such as the linear bias (1 + jnp.arange(kernel).reshape(-1, 1, 1)) / sum(range(kernel + 1)). This ramps linearly from 1 to kernel, and since sum(range(kernel + 1)) = kernel * (kernel + 1) / 2, the factors sum to 1.
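A minimal JAX sketch of the idea above, assuming a causal 1D convolution whose weight has layout [kernel, in_features, out_features] and whose last tap (index kernel - 1) lines up with the current token. The names linear_locality_bias and causal_conv1d are hypothetical illustrations, not Olmax's actual code, and the per-tap scale stands in for the proposed tensor-valued "parameter_variance" without reproducing how Olmax really uses that scalar.

```python
import jax
import jax.numpy as jnp


def linear_locality_bias(kernel: int) -> jnp.ndarray:
    # Hypothetical helper: linear ramp 1, 2, ..., kernel, normalised so the
    # taps sum to 1. Shape [kernel, 1, 1] broadcasts over a
    # [kernel, in_features, out_features] weight.
    ramp = 1 + jnp.arange(kernel).reshape(-1, 1, 1)
    return ramp / sum(range(kernel + 1))


def causal_conv1d(x: jnp.ndarray, weight: jnp.ndarray, bias_scale: jnp.ndarray) -> jnp.ndarray:
    # Hypothetical layer, assumed layouts:
    #   x:      [batch, seq, in_features]
    #   weight: [kernel, in_features, out_features]
    kernel = weight.shape[0]
    # Applying the scale in the forward pass also multiplies the gradient
    # that reaches each raw weight tap by the same factor.
    scaled = weight * bias_scale
    # Left-pad by kernel - 1 so output t only sees inputs t - kernel + 1 ... t.
    # lax convolutions do not flip the kernel, so tap kernel - 1 hits input t.
    return jax.lax.conv_general_dilated(
        x,
        scaled,
        window_strides=(1,),
        padding=[(kernel - 1, 0)],
        dimension_numbers=("NWC", "WIO", "NWC"),
    )


# Usage: an impulse at position 2 passed through an all-ones kernel of size 4.
x = jnp.zeros((1, 5, 1)).at[0, 2, 0].set(1.0)
w = jnp.ones((4, 1, 1))
y = causal_conv1d(x, w, linear_locality_bias(4))
print(y[0, :, 0])
```

Because the scale is applied in the forward pass rather than only at initialisation, it reduces both a distant tap's effective weight and the gradient that tap receives, matching the "less weight and gradient signal" described above. Under this layout, the ramp gives the current token the largest factor (kernel / sum) and the farthest token the smallest (1 / sum).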

This issue tracks the progress of such an implementation and compares our new ALiBi Convolution with the current baseline model.

ClashLuke commented 2 years ago

Addressed by #56. For more comments, see the PR.