state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Feat: Add the support for non-learnable RMS norm for large-scale training in `mamba_inner_fn` #543

Open younesbelkada opened 1 month ago

younesbelkada commented 1 month ago

Hi Albert Gu and Tri Dao,

First of all, thank you for this package. We would like to upstream some changes that were needed to train the FalconMamba-7B model using the mamba kernels.

This PR introduces a way to pass non-learnable RMS norm weights in order to normalize the B, C and dt states, as per our training procedure.

Another way could be to initialize the weight inside rms_norm_forward with torch.ones_like, but I'd prefer to force users to pass the non-learnable parameters themselves, to avoid allocating new tensors on every call of mamba_inner_fn. There might also be a way to call the RMS norm forward pass without passing RMS weights at all, but I am not sure.
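
For reference, with a frozen all-ones weight the affine RMSNorm collapses to a plain, weightless RMS normalization, which is what we want for B, C and dt. A quick pure-PyTorch sketch of that equivalence (illustration only, not the Triton code; rms_norm_ref is just a name made up for this example):

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMS-normalize over the last dimension, then apply the elementwise scale
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

d_state = 16
b_states = torch.randn(2, 64, d_state)

# a frozen all-ones weight makes the scale a no-op, i.e. plain RMS normalization;
# registering it once as a buffer avoids allocating a fresh tensor on every forward call
ones_weight = torch.ones(d_state)
plain = b_states * torch.rsqrt(b_states.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(rms_norm_ref(b_states, ones_weight), plain)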

On the transformers side, we would call the interface as follows:

        # The Triton kernel expects RMS weights to be passed even when they are non-learnable, so we create them here as frozen buffers
        self.register_buffer("b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False)
        self.register_buffer("dt_rms", torch.nn.Parameter(torch.ones(self.intermediate_size), requires_grad=False), persistent=False)
        self.rms_eps = config.mixer_rms_eps

    def cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Optional[MambaCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states).transpose(1, 2)

        if self.training and cache_params is None:  # Doesn't support outputting the states -> used for training
            contextualized_states = mamba_inner_fn(
                projected_states,
                conv1d_weight=self.conv1d.weight,
                conv1d_bias=self.conv1d.bias if self.use_conv_bias else None,
                x_proj_weight=self.x_proj.weight,
                delta_proj_weight=self.dt_proj.weight,
                out_proj_weight=self.out_proj.weight,
                out_proj_bias=self.out_proj.bias.float() if self.use_bias else None,
                A=-torch.exp(self.A_log.float()),
                B=None,  # input-dependent B
                C=None,  # input-dependent C
                D=self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                b_rms_weight=self.b_c_rms,
                c_rms_weight=self.b_c_rms,
                dt_rms_weight=self.dt_rms,
                b_c_dt_rms_eps=self.rms_eps
            )
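
To make the intent concrete: the buffer shapes above suggest where each weight is applied, b_c_rms (sized ssm_state_size) to the B and C projections coming out of x_proj, and dt_rms (sized intermediate_size) to delta after the dt projection. Roughly, in pure PyTorch (an illustration of that reading only, not the fused Triton/CUDA kernel; rms_norm_ref and the toy sizes below are made up for the example):

import torch
import torch.nn.functional as F

def rms_norm_ref(x, weight, eps):
    # RMS normalization over the last dimension with a (possibly frozen) scale
    variance = x.float().pow(2).mean(-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

# toy, illustrative sizes
batch, seqlen, d_inner, dt_rank, d_state = 2, 8, 32, 4, 16
eps = 1e-6

hidden = torch.randn(batch, seqlen, d_inner)                # post-conv activations (toy)
x_proj_weight = torch.randn(dt_rank + 2 * d_state, d_inner)
delta_proj_weight = torch.randn(d_inner, dt_rank)
b_c_rms = torch.ones(d_state)                               # frozen, as in the buffers above
dt_rms = torch.ones(d_inner)

x_dbl = F.linear(hidden, x_proj_weight)                     # (batch, seqlen, dt_rank + 2*d_state)
dt, B, C = torch.split(x_dbl, [dt_rank, d_state, d_state], dim=-1)

B = rms_norm_ref(B, b_c_rms, eps)                           # b_rms_weight
C = rms_norm_ref(C, b_c_rms, eps)                           # c_rms_weight
delta = F.linear(dt, delta_proj_weight)                     # (batch, seqlen, d_inner)
delta = rms_norm_ref(delta, dt_rms, eps)                    # dt_rms_weight
# ... followed by softplus on delta and the usual selective scan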

Thank you very much in advance! @tridao @albertfgu

younesbelkada commented 1 month ago

A simple snippet to reproduce the current issue:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, FalconMambaForCausalLM

model_id = "tiiuae/falcon-mamba-7b"
text = "Hello today we are going to"

model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
tok = AutoTokenizer.from_pretrained(model_id)

inputs = tok(text, return_tensors="pt").to(0)

with torch.no_grad():
    logits = torch.argmax(model(**inputs).logits, dim=-1)

# greedy next-token predictions with the model in eval mode
print(tok.batch_decode(logits))

model.train()
# same forward in training mode, which exercises the mamba_inner_fn path
lm_logits = model(**inputs).logits
next_token = torch.argmax(lm_logits, dim=-1)

print(tok.batch_decode(next_token))

# dummy loss, just to exercise the backward pass
loss = (1 - lm_logits).mean()
loss.backward()

younesbelkada commented 1 month ago

Hi @tridao @albertfgu, I made an alternative PR in HF transformers: https://github.com/huggingface/transformers/pull/33195 where I simply copied the kernels over there. Let me know if you see any issue with potentially merging this PR into mamba-ssm - thanks!