Fuse residual add with attention norm in Llama #11664

Open kpaigwar opened 3 weeks ago

kpaigwar commented 3 weeks ago

Presently, in the transformer decoder, we do:

h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward.forward(self.ffn_norm(h))
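
For reference, here is a minimal self-contained PyTorch sketch of this pattern, with RMSNorm written out and attention / feed-forward passed in as opaque modules; the names mirror the Llama reference code but the sketch is illustrative only:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    # RMSNorm as used in Llama: scale by the reciprocal RMS, then by a learned weight.
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

class DecoderBlock(nn.Module):
    # One decoder layer: two norms and two residual adds.
    def __init__(self, dim, attention, feed_forward):
        super().__init__()
        self.attention = attention
        self.feed_forward = feed_forward
        self.attention_norm = RMSNorm(dim)
        self.ffn_norm = RMSNorm(dim)
    def forward(self, x, start_pos, freqs_cis, mask):
        # residual add around attention
        h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
        # residual add around the MLP
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out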

We have chains of decoder layers, where each layer's output is the next layer's input:

# layer1
h1 = x0 + self.attention.forward(self.attention_norm(x0), start_pos, freqs_cis, mask)
out1 = h1 + self.feed_forward.forward(self.ffn_norm(h1))
# layer 2
h2 = out1 + self.attention.forward(self.attention_norm(out1), start_pos, freqs_cis, mask)
out2 = h2 + self.feed_forward.forward(self.ffn_norm(h2))
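
Written as a loop over the hypothetical DecoderBlock above, the same chain makes the dataflow explicit: the second residual add of each layer produces the input of the next layer.

def run_decoder_chain(layers, x, start_pos, freqs_cis, mask):
    # x is the token embedding; each layer consumes the previous layer's output
    out = x
    for layer in layers:
        h = out + layer.attention.forward(layer.attention_norm(out), start_pos, freqs_cis, mask)
        out = h + layer.feed_forward.forward(layer.ffn_norm(h))
    return out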

We can fuse the residual add with attention_norm as below, where x0 is the token embedding and mlp_out0 is a no-op (a zero tensor):

# layer1
h1 = x0 + self.attention.forward(self.attention_norm(x0 + mlp_out0), start_pos, freqs_cis, mask) 
mlp_out1 = self.feed_forward.forward(self.ffn_norm(h1))
# layer 2
h2 = (h1 + mlp_out1) + self.attention.forward(self.attention_norm(h1 + mlp_out1), start_pos, freqs_cis, mask)
mlp_out2 = self.feed_forward.forward(self.ffn_norm(h2))
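
One way to realize this (a sketch of a possible contract, not an existing TT-NN op) is a fused add+RMSNorm that returns both the normalized sum, which feeds attention, and the sum itself, which is still needed as the attention residual. Below is a plain-PyTorch reference of that contract with stand-in attention/MLP functions, plus a numerical check that the fused chain matches the unfused one:

import torch

def fused_add_rmsnorm(x, residual, weight, eps=1e-5):
    # Hypothetical fused op: residual add + RMSNorm in one kernel.
    # Returns (normalized sum, sum); the sum is reused as the attention residual.
    s = x + residual
    return s * torch.rsqrt(s.pow(2).mean(-1, keepdim=True) + eps) * weight, s

# Stand-ins for attention / MLP / RMSNorm, just to check the algebra.
torch.manual_seed(0)
dim = 8
x0 = torch.randn(2, dim)
w = torch.ones(dim)
attn = lambda t: 0.5 * t
ffn = lambda t: 0.25 * t
rmsnorm = lambda t: t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + 1e-5) * w

# unfused two-layer chain
h1 = x0 + attn(rmsnorm(x0))
out1 = h1 + ffn(rmsnorm(h1))
h2 = out1 + attn(rmsnorm(out1))
out2 = h2 + ffn(rmsnorm(h2))

# fused chain: mlp_out0 is a zero tensor; the fused op also hands back the sum
normed0, sum0 = fused_add_rmsnorm(x0, torch.zeros_like(x0), w)
h1f = sum0 + attn(normed0)
mlp_out1 = ffn(rmsnorm(h1f))
normed1, sum1 = fused_add_rmsnorm(h1f, mlp_out1, w)
h2f = sum1 + attn(normed1)
mlp_out2 = ffn(rmsnorm(h2f))

assert torch.allclose(out2, h2f + mlp_out2, atol=1e-6)
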
kpaigwar commented 3 weeks ago

fyi @cglagovichTT @johanna-rock-tt @uaydonat @xuncaiTT