Adds attention, per #61. First, I'm really sorry about taking so long, but college got complicated in the pandemic and I wasted a lot of time getting organized. Also, attention is a quite general concept, and even implementations of the same type of attention differ significantly (haiku, flax). So @david-berthelot and @aterzis-google, I would like to ask a few questions just to make sure my implementation is going in the right direction.
Thanks, just a heads up: we'll get back to you sometime next week. I think what you plan to implement makes sense, but we'll discuss in more detail next week (we might bring in more reviewers for expertise on this particular topic).
Hey @ebrevdo, I made the implementation multi-head, added the option for masking, and made the einsum comments clearer. Now I think I just need to:

- Change the dot products to linear (I'm not sure I should do it; I've seen conflicting definitions)

Is that it?
Regarding "change the dot products to linear": can someone answer this?
Pinging @ebrevdo @david-berthelot
I don't really know the answer.
The way I typically solve such situations is to write two versions, one using linear and one using einsum, and then compare them to make a call on which we like best.
With einsum the code reads like this (assuming objax's usual imports):

import jax.numpy as jn
from objax import functional

# q, k, v have shape (time, heads, head_dim); t/T index query/key positions.
scores = jn.einsum('thd,Thd->htT', q, k) * functional.rsqrt(head_dim)
if mask is not None:
    # mask is 1 where attention is allowed, 0 where it is blocked.
    scores = scores * mask - 1e10 * (1 - mask)
weights = functional.softmax(scores)  # normalizes over key positions (last axis)
attention = jn.einsum('htT,Thd->thd', weights, v)
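As a side note on the mask convention (1 = attend, 0 = block), a causal mask compatible with the snippet above could be built like this; this is an illustrative sketch, not code from the PR:

import jax.numpy as jn

t = 5  # example sequence length
# Lower-triangular 1/0 matrix: position i may attend to positions j <= i.
causal_mask = jn.tril(jn.ones((t, t)))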
Without einsum, it looks like this:

# q, k, v have shape (time, heads, head_dim).
q = q.transpose(1, 0, 2)                               # (heads, t, d)
k = k.transpose(1, 2, 0)                               # (heads, d, T)
scores = jn.matmul(q, k) * functional.rsqrt(head_dim)  # (heads, t, T)
if mask is not None:
    scores = scores * mask - 1e10 * (1 - mask)
weights = functional.softmax(scores)
v = v.transpose(1, 0, 2)                               # (heads, T, d)
attention = jn.matmul(weights, v)                      # (heads, t, d)
attention = attention.transpose(1, 0, 2)               # back to (t, heads, d)
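To actually make that comparison, here is a self-contained sketch (toy shapes and hypothetical function names; objax assumed) that runs both variants on random inputs and checks they agree:

import jax.numpy as jn
import objax
from objax import functional

t, num_heads, head_dim = 5, 2, 4
g = objax.random.Generator(42)
q = objax.random.normal((t, num_heads, head_dim), generator=g)
k = objax.random.normal((t, num_heads, head_dim), generator=g)
v = objax.random.normal((t, num_heads, head_dim), generator=g)

def attention_einsum(q, k, v):
    # head_dim ** -0.5 is the same scaling as functional.rsqrt on a float.
    scores = jn.einsum('thd,Thd->htT', q, k) * head_dim ** -0.5
    return jn.einsum('htT,Thd->thd', functional.softmax(scores), v)

def attention_matmul(q, k, v):
    scores = jn.matmul(q.transpose(1, 0, 2), k.transpose(1, 2, 0)) * head_dim ** -0.5
    return jn.matmul(functional.softmax(scores), v.transpose(1, 0, 2)).transpose(1, 0, 2)

assert jn.allclose(attention_einsum(q, k, v), attention_matmul(q, k, v), atol=1e-5)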
Though my question was not about einsum vs. dot product, but rather about using a plain matrix multiplication to create the embeddings (x -> q, k, v) versus using a linear layer (matmul + bias) to accomplish it. When reading "Attention Is All You Need" I understood we used just a matmul, but both implementations I listed in my first comment use linear layers.
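To make the two options concrete, a sketch with made-up shapes (objax assumed; whether we want the bias is exactly the open question):

import jax.numpy as jn
import objax

t, dim, num_heads, head_dim = 5, 16, 2, 4
x = objax.random.normal((t, dim))

# Option A: plain matmul, my reading of the paper (weights only, no bias).
w_q = objax.TrainVar(objax.random.normal((dim, num_heads * head_dim)))
q = jn.dot(x, w_q.value).reshape(t, num_heads, head_dim)

# Option B: a linear layer (matmul + bias), as in the haiku/flax implementations.
linear_q = objax.nn.Linear(dim, num_heads * head_dim, use_bias=True)
q = linear_q(x).reshape(t, num_heads, head_dim)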
einsum looks readable enough to me; concerning biases, I don't have enough experience to answer properly.
@aterzis-google @ebrevdo Can you provide some feedback, please?
I would follow the implementation in trax for this; trax was written by the creators of the Transformer, and their implementation should be treated as ground truth.
https://github.com/google/trax/blob/master/trax/layers/attention.py
Ok folks, I was trying to finish this and realized I just don't know how (especially how to deal with tensor inputs to dense layers). I think I'm not as familiar with the library as I thought. I'm deeply sorry for wasting your time.
No worries, we'll still be happy to have your contributions for other tasks if you have the time and motivation to contribute!