google / objax


Initial dot product attention #110

Closed: joaogui1 closed this 3 years ago

joaogui1 commented 4 years ago

Adds attention, per #61. First, I'm really sorry about taking so long, but college got complicated in the pandemic and I wasted a lot of time getting organized. Also, attention is quite a general concept, and even implementations of the same type of attention differ significantly (haiku, flax). So @david-berthelot and @aterzis-google, I would like to ask a few questions just to make sure my implementation is going in the right direction:

  1. I think I will implement dot product attention, multi-head attention, and masked attention. Is that ok? (A rough sketch of how I picture them fitting together is below.)
  2. What do you think of the dot product attention implementation? What do you think I need to change?

Thanks for your patience and the opportunity.
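Here is that rough sketch for question 1 (just to confirm the direction; the names and shapes are placeholders, and I'm using Linear for the projections only as a placeholder, not a decision):

import jax.numpy as jn
import objax
from objax import functional


class MultiHeadAttentionSketch(objax.Module):
    """Placeholder sketch: project x to q, k, v and run masked dot product attention per head."""

    def __init__(self, dim, num_heads):
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.proj_q = objax.nn.Linear(dim, dim)
        self.proj_k = objax.nn.Linear(dim, dim)
        self.proj_v = objax.nn.Linear(dim, dim)
        self.proj_out = objax.nn.Linear(dim, dim)

    def __call__(self, x, mask=None):  # x: (tokens, dim)
        t = x.shape[0]
        # Project and split into (tokens, heads, head_dim).
        q = self.proj_q(x).reshape(t, self.num_heads, self.head_dim)
        k = self.proj_k(x).reshape(t, self.num_heads, self.head_dim)
        v = self.proj_v(x).reshape(t, self.num_heads, self.head_dim)
        # Scaled dot product attention, computed for all heads at once.
        scores = jn.einsum('thd,Thd->htT', q, k) * self.head_dim ** -0.5
        if mask is not None:
            # mask broadcasts against (heads, tokens, tokens); 0 marks positions to hide.
            scores = scores * mask - 1e10 * (1 - mask)
        weights = functional.softmax(scores)
        out = jn.einsum('htT,Thd->thd', weights, v)
        # Merge heads back and apply the output projection.
        return self.proj_out(out.reshape(t, -1))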
david-berthelot commented 4 years ago

Thanks, just a heads up: we'll get back to you sometime next week. I think what you plan to implement makes sense, but we'll discuss in more detail next week (we might bring in more reviewers for expertise on this particular topic).

joaogui1 commented 3 years ago
  1. What kind of unit tests? Just comparing a couple of inputs with the desired outputs (e.g. something like the sketch below)?

2 and 3: Done.
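For the tests, I mean something along these lines (sketch only; dot_product_attention is a placeholder name for the layer's call, with q, k, v of shape (tokens, heads, head_dim)):

import numpy as np
import jax.numpy as jn

def reference_attention(q, k, v):
    # Naive numpy reference: loop over heads with an explicit softmax.
    t, h, d = q.shape
    out = np.zeros_like(v)
    for head in range(h):
        scores = q[:, head, :] @ k[:, head, :].T / np.sqrt(d)
        weights = np.exp(scores) / np.exp(scores).sum(-1, keepdims=True)
        out[:, head, :] = weights @ v[:, head, :]
    return out

def test_dot_product_attention():
    rng = np.random.RandomState(0)
    q, k, v = (rng.randn(5, 2, 4).astype(np.float32) for _ in range(3))
    expected = reference_attention(q, k, v)
    actual = dot_product_attention(jn.asarray(q), jn.asarray(k), jn.asarray(v))  # placeholder name
    np.testing.assert_allclose(np.asarray(actual), expected, rtol=1e-5)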
joaogui1 commented 3 years ago

Hey @ebrevdo, I made the implementation multi-head, added the option for masking, and made the einsum comments clearer. Now I think I just need to:

Is that it?

joaogui1 commented 3 years ago

> Change the dot products to linear (I'm not sure I should do it, I've seen conflicting definitions)

Can someone answer this?

joaogui1 commented 3 years ago

> Change the dot products to linear (I'm not sure I should do it, I've seen conflicting definitions)
>
> Can someone answer this?

Pinging @ebrevdo @david-berthelot

david-berthelot commented 3 years ago

I don't really know the answer.

The way I typically solve such situations is by doing two versions, one using linear and one using einsum, and then comparing them to make a call on which one we like best.

joaogui1 commented 3 years ago

With einsum the code reads like this:

# q, k, v: (tokens, heads, head_dim); t = query positions, T = key positions, h = heads, d = head_dim
scores = jn.einsum('thd,Thd->htT', q, k) * functional.rsqrt(head_dim)  # (h, t, T)
if mask is not None:
    # Keep unmasked scores and push masked positions to -1e10 before the softmax.
    scores = scores * mask - 1e10 * (1 - mask)
weights = functional.softmax(scores)               # softmax over the key axis T
attention = jn.einsum('htT,Thd->thd', weights, v)  # back to (tokens, heads, head_dim)

Without einsum, it would look like this (using batched matmul, since a plain dot does not batch over the head axis):

# q, k, v: (tokens, heads, head_dim)
q = q.transpose(1, 0, 2)                               # (h, t, d)
k = k.transpose(1, 2, 0)                               # (h, d, T)
scores = jn.matmul(q, k) * functional.rsqrt(head_dim)  # (h, t, T)
if mask is not None:
    scores = scores * mask - 1e10 * (1 - mask)
weights = functional.softmax(scores)
v = v.transpose(1, 0, 2)                               # (h, T, d)
attention = jn.matmul(weights, v)                      # (h, t, d)
attention = attention.transpose(1, 0, 2)               # back to (t, h, d)
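
And a quick sanity check that the two versions agree (sketch, with q, k, v of shape (tokens, heads, head_dim)):

import numpy as np
import jax.numpy as jn

rng = np.random.RandomState(0)
q, k, v = (jn.asarray(rng.randn(5, 2, 4), dtype=jn.float32) for _ in range(3))
head_dim = q.shape[-1]

einsum_scores = jn.einsum('thd,Thd->htT', q, k) / jn.sqrt(head_dim)
matmul_scores = jn.matmul(q.transpose(1, 0, 2), k.transpose(1, 2, 0)) / jn.sqrt(head_dim)
assert jn.allclose(einsum_scores, matmul_scores, atol=1e-5)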
joaogui1 commented 3 years ago

My question, though, was not about einsum vs. explicit dot products, but about how to create the projections (x -> q, k, v): with a plain matrix multiplication, or with a linear layer (matmul + bias). When reading "Attention Is All You Need" I understood that just a matmul is used, but both implementations I linked in my first comment use linear layers.
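
Concretely, the two options I'm weighing look like this (sketch only; names and sizes are placeholders):

import jax.numpy as jn
import objax

dim, num_heads, head_dim = 64, 4, 16  # placeholder sizes

# Option A: plain matrix multiplication, which is how I read "Attention Is All You Need".
w_q = objax.TrainVar(objax.random.normal((dim, num_heads * head_dim)))

def project_matmul(x):  # x: (tokens, dim)
    return jn.dot(x, w_q.value).reshape(x.shape[0], num_heads, head_dim)

# Option B: a linear layer (matmul + bias), as in the haiku and flax implementations.
linear_q = objax.nn.Linear(dim, num_heads * head_dim, use_bias=True)

def project_linear(x):
    return linear_q(x).reshape(x.shape[0], num_heads, head_dim)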

david-berthelot commented 3 years ago

einsum looks readable enough to me; concerning biases, I don't have enough experience to answer properly.

@aterzis-google @ebrevdo Can you provide some feedback, please?

ebrevdo commented 3 years ago

> einsum looks readable enough to me; concerning biases, I don't have enough experience to answer properly.
>
> @aterzis-google @ebrevdo Can you provide some feedback, please?

I would follow the implementation in trax for this; trax was written by the creators of the Transformer, and their implementation should be treated as ground truth.

https://github.com/google/trax/blob/master/trax/layers/attention.py

joaogui1 commented 3 years ago

Ok folks, I was trying to finish this and realized I just don't know how (especially how to deal with tensor inputs to dense layers). I think I'm not as familiar with the library as I thought. I'm deeply sorry for wasting your time.
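
For instance, I couldn't tell whether something like this reshape is the intended way to feed a (tokens, heads, head_dim) tensor to objax.nn.Linear, or whether there's a more idiomatic way (placeholder names and sizes):

import objax

num_heads, head_dim, dim = 4, 16, 64  # placeholder sizes
linear = objax.nn.Linear(num_heads * head_dim, dim)

def apply_to_tensor(x):  # x: (tokens, heads, head_dim)
    # Flatten the per-head dimensions so the dense layer sees a (tokens, features) matrix.
    return linear(x.reshape(x.shape[0], -1))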

david-berthelot commented 3 years ago

No worries, we'd still be happy to have your contributions on other tasks if you have the time and motivation!