Open · kpaigwar opened this issue 3 months ago
Presently, in the transformer decoder, we do:
    h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
    out = h + self.feed_forward.forward(self.ffn_norm(h))
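For reference, here is a minimal, runnable PyTorch sketch of that pre-norm block structure. The names here (`DecoderBlockSketch`, the `nn.MultiheadAttention` / `nn.Sequential` stand-ins, the omission of `start_pos`/`freqs_cis`/`mask` plumbing) are illustrative assumptions, not the actual implementation; only the placement of the two residual adds matters.

```python
# Minimal sketch of the pre-norm residual structure above (illustrative only).
# nn.MultiheadAttention and the Sequential MLP stand in for the real Attention /
# FeedForward modules; start_pos, freqs_cis and mask plumbing is omitted.
import torch
import torch.nn as nn


class DecoderBlockSketch(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, n_heads: int = 4):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, dim)
        )
        # nn.RMSNorm requires PyTorch >= 2.4; swap in nn.LayerNorm otherwise.
        self.attention_norm = nn.RMSNorm(dim)
        self.ffn_norm = nn.RMSNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = self.attention_norm(x)
        h = x + self.attention(a, a, a, need_weights=False)[0]  # residual add #1
        out = h + self.feed_forward(self.ffn_norm(h))            # residual add #2
        return out
```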
We have chains of these decoder layers:
    # layer 1
    h1 = x0 + self.attention.forward(self.attention_norm(x0), start_pos, freqs_cis, mask)
    out1 = h1 + self.feed_forward.forward(self.ffn_norm(h1))
    # layer 2
    h2 = out1 + self.attention.forward(self.attention_norm(out1), start_pos, freqs_cis, mask)
    out2 = h2 + self.feed_forward.forward(self.ffn_norm(h2))
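Continuing the sketch above (same hypothetical `DecoderBlockSketch`), the chain makes the intermediate add explicit: `out1 = h1 + mlp_out1` is materialized as a standalone elementwise add between the two layers.

```python
# Chaining two sketch blocks: out1 is materialized between the layers.
import torch

dim, hidden_dim = 64, 256
layer1 = DecoderBlockSketch(dim, hidden_dim)
layer2 = DecoderBlockSketch(dim, hidden_dim)

x0 = torch.randn(2, 16, dim)   # token embeddings: (batch, seq_len, dim)
out1 = layer1(x0)              # h1 + mlp_out1 computed inside layer1.forward
out2 = layer2(out1)            # layer 2 consumes the already-summed out1
```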
We can certainly fuse the residual add with `attention_norm` as below, where `x0` is the token embedding and `mlp_out0` is a no-op (zero):
    # layer 1 (mlp_out0 is zero, so the outer residual is just x0)
    h1 = x0 + self.attention.forward(self.attention_norm(x0 + mlp_out0), start_pos, freqs_cis, mask)
    mlp_out1 = self.feed_forward.forward(self.ffn_norm(h1))
    # layer 2 (the h1 + mlp_out1 add is folded into attention_norm)
    h2 = h1 + mlp_out1 + self.attention.forward(self.attention_norm(h1 + mlp_out1), start_pos, freqs_cis, mask)
    mlp_out2 = self.feed_forward.forward(self.ffn_norm(h2))
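As a sanity check on the algebra, here is one way the restructured block could look at the module level. The fused add+norm is only emulated by a plain-PyTorch `add_rmsnorm` helper that returns both the summed residual and its normalization (the pair a fused kernel would produce); all names are hypothetical and this is a sketch of the pattern, not the proposed kernel.

```python
# Sketch of the restructured block: each layer receives (h_prev, mlp_out_prev)
# and defers the h + mlp_out add into the next layer's attention_norm.
# add_rmsnorm emulates what a fused add+RMSNorm kernel would return:
# the summed residual and its normalization. All names are hypothetical.
import torch
import torch.nn as nn


def add_rmsnorm(x, residual, weight, eps: float = 1e-6):
    s = x + residual
    normed = s * torch.rsqrt(s.pow(2).mean(-1, keepdim=True) + eps) * weight
    return s, normed


class FusedResidualBlockSketch(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, n_heads: int = 4):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(dim, hidden_dim), nn.SiLU(), nn.Linear(hidden_dim, dim)
        )
        self.attention_norm_weight = nn.Parameter(torch.ones(dim))
        self.ffn_norm = nn.RMSNorm(dim)  # PyTorch >= 2.4

    def forward(self, h_prev: torch.Tensor, mlp_out_prev: torch.Tensor):
        # Fused step: previous layer's residual add + attention_norm in one pass.
        x, a = add_rmsnorm(h_prev, mlp_out_prev, self.attention_norm_weight)
        h = x + self.attention(a, a, a, need_weights=False)[0]
        mlp_out = self.feed_forward(self.ffn_norm(h))
        # h + mlp_out is NOT added here; the next layer's fused norm does it.
        return h, mlp_out
```

For the first layer, `mlp_out_prev` would be a zero tensor (the no-op `mlp_out0`) and `h_prev` the token embedding; after the last layer, the single remaining `h + mlp_out` add still has to be applied once before whatever follows the decoder stack.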
fyi @cglagovichTT @johanna-rock-tt @uaydonat @xuncaiTT