lucidrains / x-transformers

A simple but complete full-attention transformer with a set of promising experimental features from various papers

Multi Input/Output transformers #235

Closed: RyanKim17920 closed this issue 6 months ago

RyanKim17920 commented 6 months ago

I'm not 100% sure I implemented all the features correctly (especially the memory-related ones), but here is what I have. I've tested the code with the configurations below and they run without errors, though I haven't checked every feature yet:

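# Config 1: three input vocabularies with per-input pre-attention layers, two post-attention
# output branches, and tied embeddings. (MultiIOTransformerWrapper is the wrapper class
# proposed in this issue; Decoder comes from x_transformers.)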
import torch
from x_transformers import Decoder

model = MultiIOTransformerWrapper(
    num_tokens=[8, 4, 5],
    max_seq_len=10,
    use_abs_pos_emb=True,
    max_mem_len=10,
    shift_mem_down=1,
    emb_dropout=0.1,
    post_emb_norm=True,
    pre_attn_layers=[
        Decoder(dim=2, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True, ),
        Decoder(dim=1, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True, ),
        Decoder(dim=1, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True, )],
    post_attn_layers=[
        Decoder(dim=4, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True, ),
        Decoder(dim=8, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True, )],
    memory_tokens_interspersed_every=[1, 1, 1],
    num_memory_tokens=[2, 1, 1],
    tie_embedding=True,
    l2norm_embed=True,
    emb_frac_gradient=0.1,
    attn_z_loss_weight=0.1,
    attn_layers=Decoder(
        dim=4,
        depth=1,
        heads=1,
        rotary_pos_emb=True,
        attn_flash=True,
        use_scalenorm=True,
        ff_glu=True,
    )
)
x = torch.randint(1, 3, (1, 10, 3)).float()
print(x)
print(model(x))

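# Config 2: a single input vocabulary with two post-attention output branches,
# scalar memory-token settings, and untied embeddings.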
model = MultiIOTransformerWrapper(
    num_tokens=8,
    max_seq_len=10,
    use_abs_pos_emb=True,
    max_mem_len=10,
    shift_mem_down=1,
    emb_dropout=0.1,
    post_emb_norm=True,
    post_attn_layers=[
        Decoder(dim=1, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True, ),
        Decoder(dim=2, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True, )],
    memory_tokens_interspersed_every=1,
    num_memory_tokens=2,
    tie_embedding=False,
    l2norm_embed=True,
    emb_frac_gradient=0.1,
    attn_z_loss_weight=0.1,
    attn_layers=Decoder(
        dim=4,
        depth=1,
        heads=1,
        rotary_pos_emb=True,
        attn_flash=True,
        use_scalenorm=True,
        ff_glu=True,
    )
)
x = torch.randint(1, 3, (1, 10, 3)).float()
print(x)
print(model(x))

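# Config 3: three input vocabularies with per-input pre-attention layers,
# an explicit logits_dim per output, and untied embeddings.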
model = MultiIOTransformerWrapper(
    num_tokens=[8, 4, 5],
    max_seq_len=10,
    use_abs_pos_emb=True,
    max_mem_len=10,
    logits_dim=[1, 2],
    shift_mem_down=1,
    emb_dropout=0.1,
    post_emb_norm=True,
    pre_attn_layers=[
        Decoder(dim=2, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True, ),
        Decoder(dim=1, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True, ),
        Decoder(dim=1, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True, )],
    memory_tokens_interspersed_every=[1, 1, 1],
    num_memory_tokens=[2, 1, 1],
    tie_embedding=False,
    l2norm_embed=True,
    emb_frac_gradient=0.1,
    attn_z_loss_weight=0.1,
    attn_layers=Decoder(
        dim=4,
        depth=1,
        heads=1,
        rotary_pos_emb=True,
        attn_flash=True,
        use_scalenorm=True,
        ff_glu=True,
    )
)
x = torch.randint(1, 3, (1, 10, 3)).float()
print(x)
print(model(x))
RyanKim17920 commented 6 months ago

There seem to be some conflicts, so I'll create a new pull request to get these issues fixed.