kyegomez / BitNet

Implementation of "BitNet: Scaling 1-bit Transformers for Large Language Models" in PyTorch
https://discord.gg/qUtxnK2NMf
MIT License

[BUG] residual connection wrong? #32

Closed: qianlong0502 closed this issue 6 months ago

qianlong0502 commented 7 months ago

In bit_transformer.py:

class Transformer(nn.Module):
    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        for attn, ffn in zip(self.layers, self.ffn_layers):
            # print(x.shape)
            x, _ = attn(x, x, x, is_causal=True, *args, **kwargs)
            x = x + x
            x = ffn(x) + x
        return x

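Is the line x = x + x wrong? It does not look like a residual connection. On the previous line, x has just been overwritten with the attention output, so x + x only doubles that output, and the block's input is never added back.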
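Below is a minimal sketch of what a corrected forward pass could look like. It is not the repository's actual fix, which the thread does not show. The attention output is stored in its own variable so that the block input survives and can be added back. nn.MultiheadAttention and a small MLP stand in for BitNet's own attention and feed-forward layers. The constructor, its dim/heads/depth parameters, and the explicit boolean causal mask (used instead of the is_causal flag) are assumptions for illustration.

```python
import torch
from torch import Tensor, nn


class Transformer(nn.Module):
    """Hypothetical stand-in for the Transformer in bit_transformer.py.

    Uses nn.MultiheadAttention and a plain MLP instead of BitNet's
    BitLinear-based layers; only the residual wiring is illustrated.
    """

    def __init__(self, dim: int, heads: int, depth: int):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.MultiheadAttention(dim, heads, batch_first=True) for _ in range(depth)]
        )
        self.ffn_layers = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
                for _ in range(depth)
            ]
        )

    def forward(self, x: Tensor) -> Tensor:
        seq_len = x.shape[1]
        # Boolean causal mask: True marks future positions that may not be attended to.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )
        for attn, ffn in zip(self.layers, self.ffn_layers):
            # Keep the attention output separate so the block input is not overwritten.
            attn_out, _ = attn(x, x, x, attn_mask=causal_mask, need_weights=False)
            x = attn_out + x  # residual around attention: sublayer output + its input
            x = ffn(x) + x  # residual around the feed-forward sublayer
        return x
```

A quick way to check the wiring is to zero the output projection of every sublayer so that each one returns zeros. With correct residuals the model then returns its input unchanged. With the original x = x + x line it would return all zeros, because the input is discarded in the first block.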


kyegomez commented 6 months ago

@qianlong0502 Fixed it now, apologies. Thanks for the catch!