coreylowman / dfdx

Deep learning in Rust, with shape checked tensors and neural networks

Attention mask in TransformerDecoderBlock? #590

Open ifsheldon opened 1 year ago

ifsheldon commented 1 year ago

Hi! I've been trying to port nanoGPT to Rust with dfdx. The transformer module is awesome, but it seems an important trick is missing: the attention mask in TransformerDecoderBlock. I took a look at the lines below and didn't find anything about an attention mask. Did I miss anything?

https://github.com/coreylowman/dfdx/blob/cbe38a54fad2f58023cbceb0ea9d9e889a34e7f2/src/nn/transformer/decoder.rs#L187

https://github.com/coreylowman/dfdx/blob/cbe38a54fad2f58023cbceb0ea9d9e889a34e7f2/src/nn/transformer/mha.rs#L130

For attention masks, you can refer to Neural Networks: Zero to Hero - "Let's build GPT: from scratch, in code, spelled out", as well as the documentation of torch.nn.MultiheadAttention.forward and torch.nn.functional.scaled_dot_product_attention.

opfromthestart commented 1 year ago

I think the choose function may do what you are asking: it lets you use a boolean tensor to select element-wise between two given tensors.
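For illustration, here is a minimal sketch of that idea. It assumes dfdx's bool-tensor choose op and uses a toy 3x3 score matrix, so treat it as a sketch rather than a drop-in attention implementation:

use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    // toy 3x3 attention scores
    let scores: Tensor<Rank2<3, 3>, f32, _> = dev.sample_normal();
    // causal pattern: true = may attend, false = blocked
    let keep = dev.tensor([
        [true, false, false],
        [true, true, false],
        [true, true, true],
    ]);
    // element-wise select: keep the score where `keep` is true, else -inf,
    // so the blocked positions vanish after the softmax
    let neg_inf = dev.tensor([[f32::NEG_INFINITY; 3]; 3]);
    let masked = keep.choose(scores, neg_inf);
    dbg!(masked.softmax::<Axis<1>>().array());
}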

jafioti commented 1 year ago

Fairly certain the default impl doesn't have a causal attention mask. You'll need to add it yourself. Here's what I did to the forward function:

assert_eq!(k.shape().0, v.shape().0);
let s1 = q.shape().0;
let s2 = k.shape().0;
let v = self.w_v.try_forward(v.retaped::<T>())?;
let v = v.try_reshape_like(&(s2, H, V / H)).unwrap()?;
let v = v.try_permute::<_, Axes3<1, 0, 2>>()?;

let k = self.w_k.try_forward(k.retaped::<T>())?;
let k = k.try_reshape_like(&(s2, H, K / H)).unwrap()?;
let k = k.try_permute::<_, Axes3<1, 2, 0>>()?;

let q = self.w_q.try_forward(q)?;
let q = q.try_reshape_like(&(s1, H, K / H)).unwrap()?;
let q = q.try_permute::<_, Axes3<1, 0, 2>>()?;

// Get weights
let scalar: E = E::ONE / E::from_usize(K / H).unwrap().sqrt();
let weights = q.try_matmul(k)?.try_mul(scalar)?;
// causal mask: query position i may only attend to key positions j <= i,
// so every j > i gets -inf added to its score before the softmax
let mut mask = vec![E::zero(); s1.size() * s2.size()];
for i in 0..s1.size() {
    for j in (i + 1)..s2.size() {
        // row-major (s1, s2) layout: a row has stride s2
        mask[i * s2.size() + j] = -E::infinity();
    }
}
let mask: Tensor<(S1, S2), _, _> = weights.device.try_tensor_from_vec(mask, (s1, s2)).unwrap();
let weights = weights.try_add(mask.try_broadcast_like(&(H, s1, s2))?)?;
let weights = weights.try_softmax::<Axis<2>>()?;

// Get new tokens
let tokens = weights.try_matmul(v)?;
let tokens = tokens.try_permute::<_, Axes3<1, 0, 2>>()?;
let tokens = tokens.try_reshape_like(&(s1, Const::<V>)).unwrap()?;

self.w_o.try_forward(tokens)
ifsheldon commented 1 year ago
> let weights = weights.try_add(mask.try_broadcast_like(&(H, s1, s2))?)?;

I don't know how dfdx handles adding infinity, but in theory this is not sufficient, since addition doesn't block gradient flow in backprop, even though it does block attention in the forward pass.

jafioti commented 1 year ago

@ifsheldon It shouldn't block gradient flow, but since inf is subtracted from the masked entries, in practice their gradients should go to zero. Proper masking would be better.

The best option would be to build it directly into an MHA CUDA kernel.

coreylowman commented 1 year ago

Nope, you didn't miss anything; masks aren't currently supported. Luckily we can add them in a non-breaking way by just adding more impl Module implementations for MultiHeadAttention/Decoder/Transformer that accept an additional tensor input.

Regarding infinity, I know huggingface usually uses the float min value (https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py#L166). I'm not sure if this makes any difference in practice compared to using infinity?
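For reference, under that convention only the fill value in the mask construction from the snippet above would change; a sketch assuming an f32 dtype, with s1/s2 as in that snippet:

// same causal pattern as above, but filled with the smallest finite f32
// instead of -inf (the convention in the linked huggingface code)
let mut mask = vec![0.0f32; s1.size() * s2.size()];
for i in 0..s1.size() {
    for j in (i + 1)..s2.size() {
        mask[i * s2.size() + j] = f32::MIN;
    }
}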

Related to this is #436 which is being worked on right now.

ifsheldon commented 1 year ago

> It shouldn't block gradient flow, but since inf is subtracted from the masked entries, in practice their gradients should go to zero.

> Regarding infinity, I know huggingface usually uses the float min value (https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py#L166). I'm not sure if this makes any difference in practice compared to using infinity?

The reason subtracting inf works is that it is immediately followed by a softmax, since e^(x - inf) = e^x / e^inf = 0. However, I guess that compared to select or masking, subtracting inf makes a (naive) autograd track a lot of unnecessary compute nodes, since subtraction and softmax do not block compute flow.
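Writing the softmax out makes the forward-pass part of this concrete (standard definitions, nothing dfdx-specific):

$$\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}, \qquad x_k = -\infty \;\Rightarrow\; e^{x_k} = 0 \;\Rightarrow\; \mathrm{softmax}(x)_k = 0.$$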

Perhaps we can have a functional like torch.nn.functional.scaled_dot_product_attention?
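To make that concrete, here is a rough sketch of what such a functional might look like. The name, signature, and shapes are illustrative assumptions (not an existing dfdx API); it reuses only the try_* ops from the snippet earlier in the thread, and fixes f32/Cpu to keep the generic bounds simple:

use dfdx::prelude::*;

fn scaled_dot_product_attention<const H: usize, const S: usize, const K: usize, const V: usize>(
    q: Tensor<Rank3<H, S, K>, f32, Cpu>,              // queries, already split per head
    k: Tensor<Rank3<H, K, S>, f32, Cpu>,              // keys, already transposed per head
    v: Tensor<Rank3<H, S, V>, f32, Cpu>,              // values, already split per head
    attn_mask: Option<Tensor<Rank2<S, S>, f32, Cpu>>, // additive mask: 0.0 or -inf
) -> Tensor<Rank3<H, S, V>, f32, Cpu> {
    // scores = (q @ k) / sqrt(d_k), shape (H, S, S); unwraps kept for brevity
    let scale = 1.0 / (K as f32).sqrt();
    let mut weights = q.try_matmul(k).unwrap().try_mul(scale).unwrap();
    if let Some(mask) = attn_mask {
        // broadcast the (S, S) mask over the head axis and add it before the softmax
        let mask = mask
            .try_broadcast_like(&(Const::<H>, Const::<S>, Const::<S>))
            .unwrap();
        weights = weights.try_add(mask).unwrap();
    }
    let weights = weights.try_softmax::<Axis<2>>().unwrap();
    // (H, S, S) @ (H, S, V) -> (H, S, V)
    weights.try_matmul(v).unwrap()
}

In a real version the device and dtype would stay generic and errors would be propagated instead of unwrapped.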

coreylowman commented 1 year ago

> autograd track a lot of unnecessary compute nodes, since subtraction and softmax do not block compute flow.

Since we are doing softmax regardless, the only extra computation would be the sub op's forward/backward, right? Or am I missing something? I actually think masking and subtracting inf should result in the same amount of computation. Either way, the mask or sub(inf) doesn't need a gradient, so the only extra operation is the forward pass.

ifsheldon commented 1 year ago

> I actually think masking and subtracting inf should result in the same amount of computation.

Note that the following is purely theoretical, in terms of tracking compute flow on paper. An autograd system with sophisticated handling of inf and operator fusion should be able to get around the issue.

[Image: comparison of full attention (left), causal attention with masking (middle), and causal attention by subtracting inf (right)]

On the left is full attention, in the middle causal attention with masking, and on the right causal attention by subtracting inf. The light green box is the softmax. The black lines in the grid denote gradient flow. You can see that the black circles are disconnected by masking, while the red circles have inf subtracted from them. In theory, although the forward results and backward gradients are the same for these two methods, the number of gradient-flow routes should be halved in the middle case in an implementation. Therefore, the computation on paper can be halved as well.

But as I said, if the autograd detects the combination of subtracting inf and softmax and fuses these two ops, then the two cases may actually be the same in implementation.

coreylowman commented 1 year ago

Ahh I see, thanks for the graphic. At the moment dfdx does not support operator fusion, so they would both be the same. This is an interesting direction to go in, though; I've been thinking about fusion a lot lately with optimization on my mind.

jafioti commented 1 year ago

@coreylowman Fusion would be such a huge win for transformer MHA; just looking at the speed difference between a fused flash attention kernel and a naïve one, it's staggering.

How were you thinking of approaching this, though? One of the downsides (and upsides) of Rust is that it's much less dynamic. In PyTorch I think they can parse the whole tree of a module and rewrite it at runtime.

coreylowman commented 1 year ago

Let's move discussion of that into the issue I just made.