Closed · RyanKim17920 closed this 6 months ago
I'm not 100% sure I implemented all the features correctly (especially the memory ones), but here is what I have. I've tested the code with the snippets below and there are no errors, though I haven't checked every feature yet:
```python
import torch
# MultiIOTransformerWrapper is the new wrapper added in this PR;
# the import path assumes the PR branch of x-transformers
from x_transformers import MultiIOTransformerWrapper, Decoder

# Test 1: three input streams with per-stream pre/post attention layers,
# per-stream memory tokens, and tied embeddings
model = MultiIOTransformerWrapper(
    num_tokens=[8, 4, 5],
    max_seq_len=10,
    use_abs_pos_emb=True,
    max_mem_len=10,
    shift_mem_down=1,
    emb_dropout=0.1,
    post_emb_norm=True,
    pre_attn_layers=[
        Decoder(dim=2, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True),
        Decoder(dim=1, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True),
        Decoder(dim=1, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True),
    ],
    post_attn_layers=[
        Decoder(dim=4, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True),
        Decoder(dim=8, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True),
    ],
    memory_tokens_interspersed_every=[1, 1, 1],
    num_memory_tokens=[2, 1, 1],
    tie_embedding=True,
    l2norm_embed=True,
    emb_frac_gradient=0.1,
    attn_z_loss_weight=0.1,
    attn_layers=Decoder(dim=4, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True),
)

x = torch.randint(1, 3, (1, 10, 3)).float()  # one column per input stream
print(x)
print(model(x))

# Test 2: single input stream, two post-attention output branches
model = MultiIOTransformerWrapper(
    num_tokens=8,
    max_seq_len=10,
    use_abs_pos_emb=True,
    max_mem_len=10,
    shift_mem_down=1,
    emb_dropout=0.1,
    post_emb_norm=True,
    post_attn_layers=[
        Decoder(dim=1, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True),
        Decoder(dim=2, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True),
    ],
    memory_tokens_interspersed_every=1,
    num_memory_tokens=2,
    tie_embedding=False,
    l2norm_embed=True,
    emb_frac_gradient=0.1,
    attn_z_loss_weight=0.1,
    attn_layers=Decoder(dim=4, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True),
)

x = torch.randint(1, 3, (1, 10, 3)).float()
print(x)
print(model(x))

# Test 3: three input streams plus explicit per-output logits_dim
model = MultiIOTransformerWrapper(
    num_tokens=[8, 4, 5],
    max_seq_len=10,
    use_abs_pos_emb=True,
    max_mem_len=10,
    logits_dim=[1, 2],
    shift_mem_down=1,
    emb_dropout=0.1,
    post_emb_norm=True,
    pre_attn_layers=[
        Decoder(dim=2, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True),
        Decoder(dim=1, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True),
        Decoder(dim=1, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True),
    ],
    memory_tokens_interspersed_every=[1, 1, 1],
    num_memory_tokens=[2, 1, 1],
    tie_embedding=False,
    l2norm_embed=True,
    emb_frac_gradient=0.1,
    attn_z_loss_weight=0.1,
    attn_layers=Decoder(dim=4, depth=1, heads=1, rotary_pos_emb=True, attn_flash=True, use_scalenorm=True, ff_glu=True),
)

x = torch.randint(1, 3, (1, 10, 3)).float()
print(x)
print(model(x))
```
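As a quick sanity check beyond "runs without errors", something like this could follow each forward pass above. It is only a sketch: it assumes the wrapper returns a list/tuple of per-stream logits when `num_tokens`/`logits_dim` are lists and a single tensor otherwise, which I haven't verified against the final implementation:

```python
# Hedged sketch: inspect output shapes for each configuration above.
# Assumes multi-output configs return a list/tuple of logits tensors.
out = model(x)
if isinstance(out, (list, tuple)):
    for i, o in enumerate(out):
        print(f"output stream {i}: {tuple(o.shape)}")
else:
    print(f"single output: {tuple(out.shape)}")
```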
There seem to be some conflicts; I'll create a new pull request so that these issues are fixed.