lucidrains / x-transformers

A concise but complete full-attention transformer with a set of promising experimental features from various papers

Flash is not flash #144

Open liujuncn opened 1 year ago

liujuncn commented 1 year ago

I tested x-transformers' flash attention against HF GPT-2, both wrapped with PyTorch Lightning, but it is slower than transformers.GPT2LMHeadModel with the same config parameters. Not sure where I am going wrong?

[image: comparison chart of the two runs]

Purple is x-transformers flash attn.

```python
import pytorch_lightning as pl
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

class FlashAttentionLM(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        # decoder-only transformer with flash attention enabled
        model = TransformerWrapper(
            num_tokens = config.vocab_size,
            max_seq_len = config.seq_length,
            attn_layers = Decoder(
                dim = config.embd_size,
                depth = config.n_layer + 1,  # note: one more layer than config.n_layer
                heads = 8,                   # note: head count is hard-coded
                attn_flash = True
            )
        )
        # calling self.model on token ids returns the autoregressive loss
        self.model = AutoregressiveWrapper(model)
```
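As an aside, the decoder above uses depth = config.n_layer + 1 and hard-codes heads = 8, so for the timing comparison to be apples-to-apples the HF baseline would need to mirror those values. A minimal sketch (build_hf_baseline is a hypothetical helper, and the field names assume the same config object as above):

```python
from transformers import GPT2Config, GPT2LMHeadModel

def build_hf_baseline(config):
    # mirror the x-transformers settings above, including the extra layer
    # and the hard-coded 8 heads, so the two models are actually comparable
    hf_config = GPT2Config(
        vocab_size = config.vocab_size,
        n_positions = config.seq_length,
        n_embd = config.embd_size,
        n_layer = config.n_layer + 1,
        n_head = 8,
    )
    return GPT2LMHeadModel(hf_config)
```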

lucidrains commented 1 year ago

you are comparing the entire transformer implementation, not the attention mechanism itself
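
One way to benchmark the attention mechanism itself is to time the same x-transformers model with attn_flash toggled on and off, so everything else (embeddings, feedforwards, norms) is held constant. A rough sketch, assuming a CUDA GPU; the sizes below are arbitrary placeholders:

```python
import time
import torch
from x_transformers import TransformerWrapper, Decoder

def build(flash):
    # identical models, differing only in whether flash attention is used
    return TransformerWrapper(
        num_tokens = 50257,
        max_seq_len = 1024,
        attn_layers = Decoder(dim = 512, depth = 6, heads = 8, attn_flash = flash)
    ).cuda()

x = torch.randint(0, 50257, (8, 1024), device = 'cuda')

for flash in (False, True):
    model = build(flash).eval()
    with torch.no_grad():
        model(x)                      # warm-up
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(10):
            model(x)
        torch.cuda.synchronize()
    print(f'attn_flash = {flash}: {time.time() - start:.3f}s')
```

Flash attention's benefit generally grows with sequence length, and the fused kernels typically require half precision, so a short-sequence fp32 run may show little or no speedup.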