state-spaces / mamba

Mamba SSM architecture
Apache License 2.0
12.65k stars 1.06k forks source link

Exploding gradients if ngroups is higher than 1. #522

Open DanFosing opened 1 month ago

DanFosing commented 1 month ago

Whenever I set `n_groups` to a value greater than 1, the gradients explode after roughly 70–100M tokens. I don't know whether my config is causing the problem or Mamba-2 itself. (I'm using the Hugging Face implementation and opened an issue there: https://github.com/huggingface/transformers/issues/32570, but as you can see, I was redirected here.)

tridao commented 1 month ago

Can you give a short script to reproduce the issue? E.g. for these specific tensors, the gradients are wrong / very large.

DanFosing commented 1 month ago

Unfortunately, I can't give you more detailed information right now beyond the script I'm using:


torch.cuda.empty_cache()

seed = 42
# Seed every RNG source in the original order: `set_seed` (imported elsewhere,
# presumably transformers' helper), then Python's `random`, NumPy, the PyTorch
# CPU generator, and every CUDA device.
for _seeder in (
    set_seed,
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seeder(seed)
del _seeder  # keep the module namespace identical to the unrolled version
def print_nb_trainable_params(model):
    """Print, and return, the trainable parameter count broken down by component.

    Each parameter is bucketed by substrings of its name. The checks run in
    this order and the first match wins:

    * ``embedding``: names containing ``"embeddings"`` or ``"lm_head"``;
    * ``attention``: names containing ``"attention"``, or a ``q_proj`` /
      ``k_proj`` / ``v_proj`` / ``o_proj`` projection whose name also
      contains ``"mixer"``;
    * ``mamba``: any other parameter whose name contains ``"mixer"``;
    * ``mlp``: names containing ``"mlp"``.

    Parameters matching none of these count only toward ``total``.
    Parameters with ``requires_grad=False`` are skipped entirely, because
    the function reports trainable parameters only.

    Args:
        model: Object exposing ``named_parameters()`` yielding
            ``(name, param)`` pairs, e.g. a ``torch.nn.Module``.

    Returns:
        dict with integer counts under ``"attention"``, ``"mamba"``,
        ``"mlp"``, ``"embedding"`` and ``"total"``.
    """
    attention_projections = ("q_proj", "k_proj", "v_proj", "o_proj")
    counts = {"attention": 0, "mamba": 0, "mlp": 0, "embedding": 0, "total": 0}

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not "trainable"
        n = param.numel()
        if "embeddings" in name or "lm_head" in name:
            counts["embedding"] += n
        elif "attention" in name or (
            "mixer" in name and any(p in name for p in attention_projections)
        ):
            counts["attention"] += n
        elif "mixer" in name:
            # Names containing "attention" were already taken by the branch above.
            counts["mamba"] += n
        elif "mlp" in name:
            counts["mlp"] += n
        counts["total"] += n

    print(
        f"Num params in attention: {counts['attention'] / 1_000_000:.2f}M,"
        f" Mamba: {counts['mamba'] / 1_000_000:.2f}M,"
        f" MLP: {counts['mlp'] / 1_000_000:.2f}M,"
        f" embedding: {counts['embedding'] / 1_000_000:.2f}M"
    )
    print(f"Total Number of parameters: {counts['total'] / 1_000_000:.2f}M")
    return counts
# Without ``split=``, ``load_dataset`` returns a DatasetDict. That object is what
# would later be passed to the Trainer as ``train_dataset``, but Trainer expects
# a single Dataset there. Request the split explicitly.
# NOTE(review): assumes the repo's data lives in the "train" split — confirm on the hub.
dataset = load_dataset("SprykAI/dclm-onlytrain-20B", split="train", keep_in_memory=False)
eval_dataset = load_dataset("SprykAI/dclm-val-1100rows", split="validation")

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.model_max_length = 2048
# Reuse EOS as the padding token so batches can be padded by the collator.
# NOTE(review): DataCollatorForLanguageModeling(mlm=False) sets the labels of
# pad-token positions to -100. With pad == eos, that also removes genuine EOS
# tokens from the loss. This is harmless only if the data contains no EOS — verify.
tokenizer.pad_token = tokenizer.eos_token
batch_size = 32
def tokenize_function(examples):
    """Tokenize the ``"text"`` column of a batch with the module-level ``tokenizer``.

    Used as the mapping function for ``Dataset.map(..., batched=True)``. The
    tokenizer's output is returned unchanged.
    """
    texts = examples["text"]
    encoded = tokenizer(texts)
    return encoded

# Tokenize each split and drop its raw columns. Each split is mapped with its
# *own* column names. Reusing the train split's names for the eval split breaks
# as soon as the train split has a column the eval split lacks.
column_names = dataset.column_names
tokenized_train_dataset = dataset.map(
    tokenize_function, remove_columns=column_names, batched=True
)

tokenized_eval_dataset = eval_dataset.map(
    tokenize_function, remove_columns=eval_dataset.column_names, batched=True
)

config_class = Mamba2Config
model_class = Mamba2ForCausalLM
use_cache = False
# NOTE(review): num_heads * head_dim (32 * 64 = 2048) presumably has to equal
# expand * hidden_size (2 * 1024 = 2048) — confirm against Mamba2Config.
config = config_class(
    vocab_size=len(tokenizer),  # size the embedding to the tokenizer, not a hard-coded 32000
    hidden_size=1024,
    num_hidden_layers=12,
    expand=2,
    head_dim=64,
    n_groups=8,  # the n_groups > 1 setting reported to diverge
    num_heads=32,
    state_size=128,
    use_cache=False,
    is_training=True,
)

model = model_class(config)
# Keep the master weights in float32 and move the model to the GPU without
# casting it. Reduced precision then comes from the Trainer's mixed-precision
# flags (bf16/fp16 in TrainingArguments). The Mamba authors report that SSMs
# are sensitive to their recurrent dynamics and recommend fp32 parameters
# (e.g. PyTorch AMP) as the first step against training instability.
# Casting the whole model to bfloat16 also made AdamW update bf16 weights directly.
model.to("cuda")
print_nb_trainable_params(model)
def group_texts(examples, block_size=2048):
    """Concatenate a batch of tokenized examples and re-split it into fixed-size blocks.

    The function is shaped for ``Dataset.map(..., batched=True)``, although it
    is not called in this script. Every column of *examples* is a list of
    per-example sequences. Each column's sequences are joined end to end and
    cut into consecutive chunks of ``block_size`` items. A trailing remainder
    shorter than ``block_size`` is dropped.

    Args:
        examples: Mapping from column name to a list of sequences. Must
            contain an ``"input_ids"`` column.
        block_size: Length of every output chunk (default 2048).

    Returns:
        A dict with the same columns as *examples*, each a list of chunks,
        plus ``"labels"``. The labels are independent copies of the
        ``"input_ids"`` chunks.

    Raises:
        ValueError: if ``block_size`` is not positive.
        KeyError: if *examples* has no ``"input_ids"`` column.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    # Concatenate all texts of each column.
    concatenated = {k: list(chain.from_iterable(v)) for k, v in examples.items()}
    # Columns come from the same tokenizer call, so measure the first one.
    total_length = len(next(iter(concatenated.values()), []))

    # Number of full chunks; the incomplete tail is dropped.
    num_full_chunks = total_length // block_size
    result = {
        k: [seq[i * block_size:(i + 1) * block_size] for i in range(num_full_chunks)]
        for k, seq in concatenated.items()
    }

    # Copy every chunk so labels can later be edited in place (e.g. masked
    # with -100) without silently changing input_ids as well.
    result["labels"] = [list(chunk) for chunk in result["input_ids"]]
    return result

# Mixed precision: bf16 autocast where the GPU supports it, otherwise fp16.
bf16_supported = torch.cuda.is_bf16_supported()
# Single source of truth for where checkpoints and the final model go.
output_dir = "mambatraining"

# Per device, each optimizer step sees batch_size * gradient_accumulation_steps
# = 32 * 16 = 512 sequences.
training_args = TrainingArguments(
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_checkpointing=False,
    gradient_accumulation_steps=16,
    load_best_model_at_end=False,
    num_train_epochs=1,
    report_to=["wandb"],
    eval_strategy="steps",
    learning_rate=4e-4,
    fp16=not bf16_supported,
    bf16=bf16_supported,
    bf16_full_eval=bf16_supported,
    fp16_full_eval=not bf16_supported,
    logging_steps=2,
    adam_beta1=0.9,
    adam_beta2=0.95,
    adam_epsilon=1e-7,
    optim="adamw_torch",
    save_total_limit=400,
    # NOTE(review): evaluating every 2 optimizer steps is costly; consider a
    # larger interval if eval is only used for monitoring.
    eval_steps=2,
    save_steps=500,
    save_strategy="steps",
    weight_decay=0.1,
    max_grad_norm=1.0,  # gradient-clipping threshold
    seed=seed,
    lr_scheduler_type="cosine_with_min_lr",
    warmup_ratio=0.01,
    lr_scheduler_kwargs={"min_lr_rate": 0.1},
    push_to_hub=True,
    hub_private_repo=True,
    output_dir=output_dir,
)

# NOTE(review): the Trainer receives the hub datasets directly; the
# tokenized_* datasets built earlier are unused. The issue text says the hub
# copies are already tokenized and chunked — confirm before switching.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=eval_dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

t_out = trainer.train()
eval_out = trainer.evaluate()
trainer.save_model(output_dir)
DanFosing commented 1 month ago

The dclm-onlytrain-20B and dclm-validation datasets are just the DCLM dataset after preprocessing. I included the processing steps (tokenization and splitting into chunks) in the script so you can see exactly what I did to those datasets.