state-spaces / mamba

Mamba SSM architecture

Gradient explosion in Mamba2 training, norm and loss divergence #529

Open edwko opened 1 month ago

edwko commented 1 month ago

Hi, I'm experiencing an issue with clip_grad_norm_ and loss values while training Mamba2. After training for some time, the gradient norm starts to increase rapidly toward infinity, and if training continues, the loss eventually becomes NaN.

With gradient accumulation: [grad norm plot: grad]

Training with no gradient accumulation: [grad norm plot: no_grad]

Below is a simple training script that reproduces this issue. I'm wondering if I'm doing something incorrect:

from transformers import AutoTokenizer
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from tqdm import tqdm
import time
import torch
import torch.nn.functional as F
import requests
import polars as pl

device = torch.device("cuda")
tokenizer = AutoTokenizer.from_pretrained("mistral-community/Mistral-7B-v0.2")
# Mamba2-only model (d_intermediate=0, so no MLP blocks), with 8 SSM groups instead of the default 1.
cfg = MambaConfig(
    d_model = 1024,
    d_intermediate = 0,
    n_layer = 20,
    vocab_size = 32768,
    ssm_cfg = {"ngroups": 8, "layer": "Mamba2"},
    rms_norm = True,
    residual_in_fp32 = True,
    fused_add_norm = True,
    pad_vocab_size_multiple = 8,
    tie_embeddings = True)
model = MambaLMHeadModel(cfg).to(device)

# Stream-download the dataset shard with a progress bar.
def download_file(url, path):
    response = requests.get(url, stream=True)
    total_size = int(response.headers.get('content-length', 0))
    with open(path, 'wb') as file, tqdm(
        desc=path, total=total_size, unit='iB', unit_scale=True, unit_divisor=1024
    ) as progress_bar:
        for data in response.iter_content(chunk_size=1024):
            size = file.write(data)
            progress_bar.update(size)

url = "https://huggingface.co/datasets/mlfoundations/dclm-baseline-1.0-parquet/resolve/main/filtered/OH_eli5_vs_rw_v2_bigram_200k_train/fasttext_openhermes_reddit_eli5_vs_rw_v2_bigram_200k_train/processed_data/global-shard_01_of_10/local-shard_0_of_10/shard_00000010_processed.parquet"
download_file(url, "temp.parquet")

# Build one flat token stream: wrap each document in <s>...</s>, tokenize in
# batches of 25 documents, and stop once roughly 5M tokens have been collected.
df = pl.read_parquet("temp.parquet")['text']
tokens = []
enc_batch = []
for text in tqdm(df):
    enc_batch.append("<s>" + text + "</s>\n")
    if len(enc_batch) == 25:
        for t in tokenizer.batch_encode_plus(enc_batch, add_special_tokens=False)['input_ids']:
            tokens.extend(t)
        enc_batch = []
        if len(tokens) >= 5_000_000: 
            break
del df, enc_batch

block = 4096                # tokens per micro-batch
gradient_accumulation = 12  # optimizer step every 12 micro-batches
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, betas=(0.9, 0.95))
tq = tqdm(range(0, len(tokens), block))
tloss = 0
model.train()
optimizer.zero_grad()
for p, i in enumerate(tq):
    # Next-token prediction: targets are the inputs shifted by one position.
    src = torch.tensor([tokens[i:i+block]], dtype=torch.int64).to(device)
    tgt = torch.tensor([tokens[i+1:i+block+1]], dtype=torch.int64).to(device)
    if src.size(-1) != block or tgt.size(-1) != block:
        break  # drop the final partial block
    logits = model(src).logits
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), tgt.view(-1))
    loss = loss / gradient_accumulation
    loss.backward()
    tloss += loss.item()
    if (p + 1) % gradient_accumulation == 0:
        # clip_grad_norm_ returns the total gradient norm before clipping.
        clip_grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0).item()
        optimizer.step()
        optimizer.zero_grad()
        tq.set_description(f"loss: {tloss:.4f}, grad_norm: {clip_grad_norm:.4f} | p: {p} | i: {i}")
        tloss = 0
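
One common safeguard while debugging this kind of divergence is to skip the optimizer step whenever the pre-clip gradient norm is non-finite, so a single blown-up batch cannot poison the weights. A minimal sketch, reusing the model and optimizer from the script above; the clipped_step helper name is illustrative and not part of mamba_ssm:

import math
import torch

def clipped_step(model, optimizer, max_norm=1.0):
    # Clip gradients to max_norm and return the pre-clip total norm.
    # The optimizer step is applied only when that norm is finite.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm).item()
    if math.isfinite(grad_norm):
        optimizer.step()
    optimizer.zero_grad()
    return grad_norm

In the accumulation branch of the loop above, the clip/step/zero_grad lines would collapse to clip_grad_norm = clipped_step(model, optimizer), with the logging unchanged. This only masks the symptom, of course; it does not explain the underlying divergence.
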
vasqu commented 1 month ago

Possibly related to #522

DanFosing commented 4 weeks ago

@edwko Could you try to decrease ngroups to 1 to see if the issue is related to the one I'm having?

edwko commented 4 weeks ago

> @edwko Could you try to decrease ngroups to 1 to see if the issue is related to the one I'm having?

@DanFosing Yes, with ngroups set to 1, which is the default, the training process is stable. I've processed over 30 billion tokens and everything still looks fine.
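
For reference, the stable run described above differs from the repro only in the ssm_cfg; a minimal sketch of that configuration, assuming every other field stays exactly as in the original script:

# Identical to the repro config except that ngroups is left at its
# Mamba2 default of 1, the setting reported stable past 30B tokens.
cfg = MambaConfig(
    d_model = 1024,
    d_intermediate = 0,
    n_layer = 20,
    vocab_size = 32768,
    ssm_cfg = {"layer": "Mamba2"},  # ngroups omitted, so it defaults to 1
    rms_norm = True,
    residual_in_fp32 = True,
    fused_add_norm = True,
    pad_vocab_size_multiple = 8,
    tie_embeddings = True)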