Open DanFosing opened 1 month ago
Can you give a short script to reproduce the issue? E.g. for these specific tensors, the gradients are wrong / very large.
Unfortunately, I can't give you more detailed information right now, other than the script I'm using:
# Release unused memory held by PyTorch's CUDA caching allocator, then make the
# run reproducible. transformers' `set_seed` is documented to seed Python's
# `random`, NumPy and PyTorch itself; the explicit calls that follow simply
# re-seed those same generators with the same value (belt and braces).
torch.cuda.empty_cache()
seed = 42
for _seed_fn in (
    set_seed,
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(seed)
def print_nb_trainable_params(model):
    """Print a per-component breakdown of the model's trainable parameter count.

    Parameters are bucketed by their qualified name:

    * embedding: names containing ``embeddings`` or ``lm_head``;
    * attention: names containing ``attention``, or a q/k/v/o projection
      that lives inside a ``mixer``;
    * Mamba: any other name containing ``mixer``;
    * MLP: names containing ``mlp``.

    Parameters matching none of these (e.g. norms) count only toward the total.
    Frozen parameters (``requires_grad=False``) are skipped, as the function
    name promises.

    Args:
        model: Any object exposing ``named_parameters()`` (e.g. ``torch.nn.Module``).

    Returns:
        dict: Exact integer counts under ``"attention"``, ``"mamba"``, ``"mlp"``,
        ``"embedding"`` and ``"total"``. (The original returned ``None``;
        existing callers can keep ignoring the return value.)
    """
    counts = {"attention": 0, "mamba": 0, "mlp": 0, "embedding": 0, "total": 0}
    attn_projs = ("q_proj", "k_proj", "v_proj", "o_proj")
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # numel() is an exact Python int (np.prod(shape) gives a numpy scalar).
        n = param.numel()
        if "embeddings" in name or "lm_head" in name:
            counts["embedding"] += n
        elif "attention" in name or ("mixer" in name and any(p in name for p in attn_projs)):
            counts["attention"] += n
        elif "mixer" in name:  # "attention" names were already taken above
            counts["mamba"] += n
        elif "mlp" in name:
            counts["mlp"] += n
        counts["total"] += n
    print(
        f"Num params in attention: {counts['attention'] / 1_000_000:.2f}M,"
        f" Mamba: {counts['mamba'] / 1_000_000:.2f}M,"
        f" MLP: {counts['mlp'] / 1_000_000:.2f}M,"
        f" embedding: {counts['embedding'] / 1_000_000:.2f}M"
    )
    print(f"Total Number of parameters: {counts['total'] / 1_000_000:.2f}M")
    return counts
# Load a single split explicitly: without `split=...`, `load_dataset` returns a
# DatasetDict, which has no single `.column_names` list and cannot be passed to
# Trainer as `train_dataset`. NOTE(review): assumes the hub repo exposes a
# "train" split (the repo name suggests so) — confirm.
dataset = load_dataset("SprykAI/dclm-onlytrain-20B", split="train", keep_in_memory=False)
eval_dataset = load_dataset("SprykAI/dclm-val-1100rows", split="validation")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.model_max_length = 2048
# No dedicated pad token: reuse EOS for padding.
tokenizer.pad_token = tokenizer.eos_token
batch_size = 32


def tokenize_function(examples):
    """Tokenize a batch of raw documents taken from the "text" column."""
    return tokenizer(examples["text"])


# Drop each dataset's own original columns (e.g. "text") so only the tokenizer
# outputs remain. The eval split must use its own column list, not the train one.
column_names = dataset.column_names
tokenized_train_dataset = dataset.map(
    tokenize_function, remove_columns=column_names, batched=True)
tokenized_eval_dataset = eval_dataset.map(
    tokenize_function, remove_columns=eval_dataset.column_names, batched=True)
config_class = Mamba2Config
model_class = Mamba2ForCausalLM
use_cache = False
# Shape sanity for this config: expand * hidden_size = 2 * 1024 = 2048 equals
# num_heads * head_dim = 32 * 64, and n_groups=8 divides num_heads=32.
# NOTE(review): presumably both relations are required by the Mamba2 mixer — confirm.
config = config_class(
    vocab_size=32000,
    hidden_size=1024,
    num_hidden_layers=12,
    expand=2,
    head_dim=64,
    n_groups=8,
    num_heads=32,
    state_size=128,
    use_cache=use_cache,
    is_training=True,
)
# Size the embedding/LM head to the actual tokenizer.
config.vocab_size = len(tokenizer)
model = model_class(config)
# Keep master weights (and hence AdamW state) in float32 and let the Trainer's
# `bf16=True` / `fp16=True` autocast handle reduced precision. Casting the whole
# model to bfloat16 here made the optimizer update bf16 weights directly. Small
# updates are then rounded away, which is a known source of SSM training
# instability. It also mixes bf16 weights with the fp16-autocast fallback.
model.to("cuda")
print_nb_trainable_params(model)
def group_texts(examples, block_size=2048):
    """Concatenate a batch of tokenized examples and cut it into fixed-size chunks.

    Intended for ``Dataset.map(..., batched=True)``: each column of ``examples``
    (e.g. ``input_ids``, ``attention_mask``) is a list of per-example token lists.
    Every column is flattened and split into ``block_size``-long chunks; the
    trailing remainder shorter than ``block_size`` is dropped.

    Args:
        examples: Mapping from column name to a list of token lists. Must contain
            an ``"input_ids"`` column, which is copied into ``"labels"``.
        block_size: Chunk length in tokens (default 2048, the previous hard-coded value).

    Returns:
        dict: The same columns plus ``"labels"``, each a list of ``block_size``-long lists.

    Raises:
        ValueError: If ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    # Concatenate all texts of each column into one flat token list.
    concatenated = {k: list(chain.from_iterable(v)) for k, v in examples.items()}
    # All columns share the same total length; measure the first one.
    total_length = len(next(iter(concatenated.values())))
    num_full_chunks = total_length // block_size
    # Split into full chunks only, dropping the incomplete tail.
    result = {
        k: [tokens[i * block_size:(i + 1) * block_size] for i in range(num_full_chunks)]
        for k, tokens in concatenated.items()
    }
    # Causal-LM labels are the inputs themselves (the model shifts them internally).
    result["labels"] = result["input_ids"].copy()
    return result
# Probe bf16 support once: use bf16 mixed precision when available, else fp16.
bf16_ok = torch.cuda.is_bf16_supported()
training_args = TrainingArguments(
    output_dir="mambatraining",
    seed=seed,
    # Batching: 32 sequences per device x 16 accumulation steps = 512 sequences
    # per optimizer step on each device.
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=16,
    gradient_checkpointing=False,
    num_train_epochs=1,
    # Optimizer: AdamW with betas (0.9, 0.95), decoupled weight decay, grad clipping.
    optim="adamw_torch",
    learning_rate=4e-4,
    adam_beta1=0.9,
    adam_beta2=0.95,
    adam_epsilon=1e-7,
    weight_decay=0.1,
    max_grad_norm=1.0,
    # Schedule: 1% warmup, then cosine decay down to 10% of the peak LR.
    lr_scheduler_type="cosine_with_min_lr",
    warmup_ratio=0.01,
    lr_scheduler_kwargs={"min_lr_rate": 0.1},
    # Precision (training and full-precision-free evaluation).
    bf16=bf16_ok,
    fp16=not bf16_ok,
    bf16_full_eval=bf16_ok,
    fp16_full_eval=not bf16_ok,
    # Logging / evaluation cadence (in optimizer steps).
    report_to=["wandb"],
    logging_steps=2,
    eval_strategy="steps",
    eval_steps=2,
    # Checkpointing and Hub upload.
    save_strategy="steps",
    save_steps=500,
    save_total_limit=400,
    load_best_model_at_end=False,
    push_to_hub=True,
    hub_private_repo=True,
)
# NOTE(review): the hub datasets are passed directly (not the tokenized_* maps
# built above), presumably because those hub copies are already tokenized and
# chunked — confirm they contain "input_ids".
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=eval_dataset,
    # mlm=False: causal-LM collation; labels are built from input_ids with pad
    # positions set to -100. NOTE(review): because the tokenizer's pad token is
    # its EOS token, any genuine EOS tokens are masked out of the loss as well.
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
t_out = trainer.train()
eval_out = trainer.evaluate()
trainer.save_model("./mambatraining")
The dclm-onlytrain-20B and dclm-val-1100rows datasets are just preprocessed versions of the DCLM dataset. I included the processing steps (tokenization and splitting into chunks) so you can see exactly what I did to those datasets.
Whenever I set `n_groups` to a value higher than 1, the gradients explode after roughly 70–100M tokens. I have no idea whether my config is causing the problem or Mamba-2 itself. I'm using the Hugging Face implementation and opened an issue there (https://github.com/huggingface/transformers/issues/32570), but as you can see, I was redirected to you.