sustcsonglin / flash-linear-attention

Efficient implementations of state-of-the-art linear attention models in PyTorch and Triton
MIT License

Minor mamba-2 fixes #40

Closed · DanFosing closed this 1 month ago

DanFosing commented 1 month ago

Some minor Mamba-2 fixes: previously there could be issues with PEFT, which should now be resolved. I also updated the default Mamba-2 state_size to 64, matching https://github.com/state-spaces/mamba. Mamba-2 handles larger states more efficiently, so its default state_size should be 64 or 128 rather than the 16 used for Mamba-1.
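
For illustration, here is a minimal sketch of opting into the new default explicitly (the `fla.models` import path is an assumption on my side; the field names follow the `Mamba2Config` dump in the log below):

```python
# Hedged sketch, not code from this PR: build a Mamba-2 model with the
# updated default state size. The import path is assumed, not verified here.
from fla.models import Mamba2Config, Mamba2ForCausalLM

config = Mamba2Config(
    hidden_size=2048,
    num_hidden_layers=48,
    state_size=64,  # previously defaulted to 16, the Mamba-1 value
)
model = Mamba2ForCausalLM(config)
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")
```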

Training benchmark output:

```
Initializing mamba2 model from the config:
Mamba2Config {
  "bos_token_id": 1,
  "chunk_size": 256,
  "conv_kernel": 4,
  "eos_token_id": 2,
  "expand": 2,
  "fuse_cross_entropy": true,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.1,
  "layer_norm_epsilon": 1e-05,
  "model_type": "mamba2",
  "n_groups": 8,
  "norm_before_gate": true,
  "num_heads": 64,
  "num_hidden_layers": 48,
  "pad_token_id": 0,
  "rescale_prenorm_residual": false,
  "residual_in_fp32": true,
  "rms_norm": true,
  "state_size": 16,
  "tie_word_embeddings": false,
  "time_step_floor": 0.0001,
  "time_step_limit": [
    0.0,
    Infinity
  ],
  "time_step_max": 0.1,
  "time_step_min": 0.001,
  "time_step_rank": 128,
  "transformers_version": "4.43.4",
  "use_bias": false,
  "use_cache": true,
  "use_conv_bias": true,
  "vocab_size": 32000
}

Mamba2ForCausalLM(
  (backbone): Mamba2Model(
    (embeddings): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-47): 48 x Mamba2Block(
        (norm): RMSNorm(2048, eps=1e-05)
        (mixer): Mamba2Mixer(
          (act): SiLU()
          (conv1d): Conv1d(4352, 4352, kernel_size=(4,), stride=(1,), padding=(3,), groups=4352)
          (in_proj): Linear(in_features=2048, out_features=8512, bias=False)
          (norm): FusedRMSNormSwishGate(4096, eps=1e-05)
          (out_proj): Linear(in_features=4096, out_features=2048, bias=False)
        )
      )
    )
    (norm_f): RMSNorm(2048, eps=1e-05)
  )
  (lm_head): Linear(in_features=2048, out_features=32000, bias=False)
)
Number of parameters in total: 1371839488 (1.28GiB)
Allocated memory after initialization: 2.56GiB
Max memory allocated: 37.25GiB: 100%|████████████████████████████████████████████████████████████████████████████████| 16/16 [00:43<00:00,  2.69s/it]
Throughput:   19337.15 tokens/s: 100%|████████████████████████████████████████████████████████████████████████████████| 32/32 [00:27<00:00,  1.18it/s]
```
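
As a side note, the printed layer widths are consistent with the config values; a quick sanity check of the arithmetic (my own derivation, assuming the standard Mamba-2 input projection layout of z, x, B, C plus per-head dt; not part of this PR):

```python
# Derive the printed module widths from the config dump above.
hidden_size, expand, n_groups, state_size, num_heads = 2048, 2, 8, 16, 64

d_inner  = expand * hidden_size                                 # 4096
conv_dim = d_inner + 2 * n_groups * state_size                  # 4352 Conv1d channels (x, B, C)
in_proj  = 2 * d_inner + 2 * n_groups * state_size + num_heads  # 8512 (z, x, B, C, dt)

assert conv_dim == 4352 and in_proj == 8512  # matches the module print above
```

With state_size=64 these widths grow to 5120 and 9280 respectively, which is the main cost of the new default.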