Open liujuncn opened 1 year ago
I tested Flash attention vs HF GPT2 with a PyTorch Lightning wrapper, but it is slower than transformers.GPT2LMHeadModel with the same config parameters. Not sure where I am going wrong?
Purple is x-transformers flash attn.
```python
import pytorch_lightning as pl
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

class FlashAttentionLM(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        model = TransformerWrapper(
            num_tokens = config.vocab_size,
            max_seq_len = config.seq_length,
            attn_layers = Decoder(
                dim = config.embd_size,
                depth = config.n_layer + 1,
                heads = 8,
                attn_flash = True
            )
        )
        self.model = AutoregressiveWrapper(model)
```
you are comparing the entire transformer implementation, not the attention mechanism itself
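To see the effect of flash attention on its own, one can time just the attention op in isolation rather than two full training loops. Below is a minimal sketch (not from this issue) assuming a CUDA device and PyTorch >= 2.0, where `F.scaled_dot_product_attention` can dispatch to a fused flash kernel; the shapes and iteration counts are arbitrary:

```python
import torch
import torch.nn.functional as F

def naive_attention(q, k, v):
    # plain softmax(QK^T / sqrt(d)) V, materializes the full attention matrix
    scale = q.shape[-1] ** -0.5
    attn = (q @ k.transpose(-2, -1) * scale).softmax(dim=-1)
    return attn @ v

def bench(fn, q, k, v, iters=50):
    # warm up, then time with CUDA events
    for _ in range(5):
        fn(q, k, v)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(q, k, v)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call

if torch.cuda.is_available():
    b, h, n, d = 8, 8, 1024, 64
    q, k, v = (torch.randn(b, h, n, d, device='cuda', dtype=torch.float16) for _ in range(3))
    print('naive attention     :', bench(naive_attention, q, k, v), 'ms')
    print('sdpa (flash-capable):', bench(
        lambda q, k, v: F.scaled_dot_product_attention(q, k, v, is_causal=True), q, k, v), 'ms')
```

Measured this way, the fused kernel's advantage grows with sequence length; at the whole-model level it can be masked by everything else the two implementations do differently (embedding size, depth, norm placement, the extra layer from `config.n_layer + 1`, and so on).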