kyegomez / Zamba

Implementation of the Paper: "Zamba: A Compact 7B SSM Hybrid Model" in PyTorch
MIT License

[BUG] #13

Open lapetiteclef opened 1 week ago

lapetiteclef commented 1 week ago

I was thinking of trying this Zamba implementation. While reading the code, one loop looked odd to me. Is this some intentional global weight-sharing scheme related to the fractal shared-weight part? As written, it seems to feed the same input into every layer and discard all outputs except the last one. Or is it a typo? If it is a typo, the loop should be:

    out = x
    for layer in self.layers:
        out = layer(out)

For reference, here is the current forward:
    def forward(self, x) -> Tensor:
        # Embed tokens
        x = self.embed(x)

        if self.post_embed_norm is not False:
            x = self.norm(x)

        # Each iteration calls layer(x) on the original x, so `out` is
        # overwritten every time and only the last layer's output survives
        for layer in self.layers:
            out = layer(x)

        # return OutputHead(self.dim, 1, self.vocab_size)(x)
        if self.output_head_on is not False:
            # A fresh OutputHead is constructed on every call, and this
            # branch never returns a value
            out = OutputHead(self.dim, 1, self.vocab_size)(x)
        else:
            return out
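
For comparison, here is a minimal runnable sketch of what a corrected forward might look like, assuming the layers are meant to run sequentially and the output head is meant to be built once in __init__. The class name, the stand-in nn.Linear blocks, and the out_head attribute are hypothetical placeholders, not names from this repo:

    import torch
    from torch import nn, Tensor

    class ZambaSketch(nn.Module):
        # Hypothetical stand-in model, only to illustrate the control flow
        def __init__(self, dim: int, vocab_size: int, depth: int):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, dim)
            self.norm = nn.LayerNorm(dim)
            self.post_embed_norm = True
            self.output_head_on = True
            # Placeholder blocks; the real repo uses its own Zamba layers
            self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
            # Built once here so its weights are registered and trained
            self.out_head = nn.Linear(dim, vocab_size)

        def forward(self, x: Tensor) -> Tensor:
            x = self.embed(x)
            if self.post_embed_norm:
                x = self.norm(x)
            # Chain each layer's output into the next layer's input
            for layer in self.layers:
                x = layer(x)
            # Project to vocab logits, returning in both branches
            if self.output_head_on:
                return self.out_head(x)
            return x

    tokens = torch.randint(0, 1000, (2, 16))
    logits = ZambaSketch(dim=64, vocab_size=1000, depth=4)(tokens)
    print(logits.shape)  # torch.Size([2, 16, 1000])

If the repeated layer(x) call is actually intentional weight sharing, a comment in the source explaining that would help; otherwise the chained version above matches the fix suggested at the top.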


github-actions[bot] commented 1 week ago

Hello there, thank you for opening an issue! 🙏🏻 The team was notified and will get back to you ASAP.