erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options with JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

Loss increases randomly #61

Closed: infamix closed this issue 6 months ago.

infamix commented 6 months ago

Hello once again, I am seeing some weird behavior in my loss whenever I use EasyDeL for fine-tuning, no matter the dataset. [image: loss graph]

These are my training args:

train_args = TrainArguments(
    model_class=type(model),
    configs_to_init_model_class=configs_to_init_model_class,
    custom_rule=config.get_partition_rules(True),
    model_name='EasyDelLLama2',
    num_train_epochs=1,
    learning_rate=4e-05,
    learning_rate_end=1.5e-06,
    warmup_steps=156,
    optimizer='adamw',
    scheduler='warm_up_linear',
    weight_decay=0.01,
    total_batch_size=32,
    max_steps=None,
    do_train=True,
    do_eval=False,
    backend='tpu',
    max_length=max_length,
    gradient_checkpointing='nothing_saveable',
    sharding_array=(1, -1, 1),
    use_pjit_attention_force=False,
    gradient_accumulation_steps=1,
    remove_ckpt_after_load=True,
    ids_to_pop_from_dataset=['token_type_ids'],
    loss_remat='',
    is_left_padded=True,
    dtype=jax.numpy.bfloat16
)

trainer = CausalLMTrainer(
    train_args,
    dataset_train,
    ckpt_path=None
)

Training dataset is Open-Platypus (left-padded) and the model is Sheared-Llama-2.7B.
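Because the dataset is left-padded and the arguments above set `is_left_padded=True`, the following pure-Python sketch shows what a single padded row looks like. `left_pad` is a hypothetical helper, not an EasyDeL or transformers function, and the token ids are made up, assuming Llama-style ids (BOS = 1, EOS = 2). Since the full script further down sets `pad_token = eos_token`, the pad id equals the EOS id, so only the attention mask distinguishes padding from a real EOS.

```python
from typing import List, Tuple


def left_pad(ids: List[int], max_length: int, pad_id: int) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: truncate to max_length, then left-pad ids and build the attention mask."""
    ids = ids[:max_length]  # keep the first tokens, like truncation_side="right" (the transformers default)
    n_pad = max_length - len(ids)
    input_ids = [pad_id] * n_pad + ids
    attention_mask = [0] * n_pad + [1] * len(ids)  # 0 marks padding, 1 marks real tokens
    return input_ids, attention_mask


# Made-up ids: 1 = BOS, 2 = EOS; the pad id is also 2 because pad_token = eos_token.
input_ids, attention_mask = left_pad([1, 450, 4996, 2], max_length=6, pad_id=2)
print(input_ids)       # [2, 2, 1, 450, 4996, 2]
print(attention_mask)  # [0, 0, 1, 1, 1, 1]
```

The leading `2`s and the trailing `2` are the same id; only the mask tells the model which ones to ignore.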

infamix commented 6 months ago

[image: loss graph] A similar loss graph with the same dataset but slightly different training hyperparameters.

erfanzar commented 6 months ago

I have some questions:

  1. Which TPU version are you using, and how much memory does it have?
  2. Which dataset are you using to train the model?

The issue you are facing right now is not caused by EasyDeL; it's most likely caused by your training hyperparameters.

[image: training loss graphs]

All the models in the graph above are Llama2-7B except clean-sound-25 (Llama2-13B).

I have trained more than 30 models and haven't seen any issues like this.

If both the model and the dataset you are using are open source, you can share them with me so I can check them and train the model for you.

infamix commented 6 months ago

The training dataset is Open-Platypus (instruction fine-tuning), and I am using Kaggle's TPU v3-8.

erfanzar commented 6 months ago

Are you using 8-bit or other custom-bit precision to train this model?

Because I don't think there is any other way to train a 7B model on 128 GB of device memory with total_batch_size=32.
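The question comes down to memory arithmetic. A rough back-of-envelope sketch, assuming full fine-tuning with AdamW where parameters, gradients and both AdamW moments are all stored in bfloat16 (2 bytes each), using nominal parameter counts, and taking a TPU v3-8 as 8 cores with 16 GiB of HBM each. Activations, which grow with `total_batch_size` and `max_length`, are deliberately ignored, so the real footprint is higher. `training_state_gib` is a hypothetical helper.

```python
GIB = 1024 ** 3


def training_state_gib(n_params: float, param_bytes: int = 2, moment_bytes: int = 2) -> float:
    """Rough size of params + grads + the two AdamW moments; activations are ignored."""
    per_param = 2 * param_bytes + 2 * moment_bytes  # (params + grads) + (mu + nu)
    return n_params * per_param / GIB


TPU_V3_8_HBM_GIB = 8 * 16  # 8 TPU v3 cores x 16 GiB each

for name, n_params in [("Llama2-7B", 7e9), ("Sheared-LLaMA-2.7B", 2.7e9)]:
    gib = training_state_gib(n_params)
    print(f"{name}: ~{gib:.1f} GiB of {TPU_V3_8_HBM_GIB} GiB before activations")
```

Under these assumptions the 7B model needs roughly 52 GiB for its training state versus roughly 20 GiB for the 2.7B model, which leaves much more room for the activations of a 32 x 2048 batch. If the optimizer moments are kept in float32, pass `moment_bytes=4`.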

infamix commented 6 months ago

Hi, it's the 2.7B Llama from Princeton-NLP. No custom bits either.

infamix commented 6 months ago

Let me show you my full training code:

from EasyDel import TrainArguments, CausalLMTrainer, AutoEasyDelModelForCausalLM
from transformers import AutoTokenizer
import jax
import flax
from datasets import load_dataset

model_id = 'bn22/Sheared-LLaMA-2.7B-Sharded'
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    padding_side="left",
    add_eos_token=True,
    add_bos_token=True,
)
tokenizer.pad_token = tokenizer.eos_token
sequence_max_length = 2048

#@title Open-Platypus dataset preparation
dataset = load_dataset("garage-bAInd/Open-Platypus", split = "train")

def prompt_formatting_alpaca_to_sharegpt(example):
    return {"conversations": [
                    {"from": "system", "value": example['input']},
                    {"from": "human", "value": example['instruction']},
                    {"from": "gpt", "value": example['output']},
                ],}

dataset = dataset.map(prompt_formatting_alpaca_to_sharegpt, num_proc=4)

def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = []
    mapper = {"system" : "SYSTEM:", "human" : "USER:", "gpt" : "ASSISTANT:"}
    end_mapper = {"system" : "\n\n", "human" : "\n", "gpt" : f"{tokenizer.eos_token}\n"}
    for convo in convos:
        text = "".join(f"{mapper[x['from']]} {x['value']}{end_mapper[x['from']]}" for x in convo)
        texts.append(text)
    return { "text" : texts, }

columns_to_remove = dataset.column_names
dataset = dataset.map(formatting_prompts_func, batched = True, remove_columns=columns_to_remove)

dataset_train_raw = dataset

def tokenize_function(examples):
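    # truncation=True keeps every row at exactly sequence_max_length tokens, even for long examples
    return tokenizer(examples["text"], padding='max_length', truncation=True, max_length=sequence_max_length)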

dataset_train = dataset_train_raw.map(tokenize_function, batched=True, num_proc=4, remove_columns=dataset_train_raw.column_names)

# The dataset should only contain numerical fields for the model, such as input_ids, attention_mask, ...
model, params = AutoEasyDelModelForCausalLM.from_pretrained(
    model_id,
    dtype=jax.numpy.bfloat16,
    param_dtype=jax.numpy.bfloat16,
    precision=jax.lax.Precision('fastest'),
    device=jax.devices('cpu')[0],  # Load JAX Model and initialize or load Parameters on CPU 
    # The Rest of kwargs here will be passed to AutoModelForCausalLM huggingface such as this device_map
    device_map='auto'
)
config = model.config

# this part of code is only for making model faster and more optimized 
config.freq_max_position_embeddings = config.max_position_embeddings
config.max_position_embeddings = 4096
config.c_max_position_embeddings = config.max_position_embeddings
config.use_pjit_attention_force = False  # disabling pjit attention force is recommended when using MP = 1 in the sharding mesh

max_length = sequence_max_length

configs_to_init_model_class = {
    'config': config,
    'dtype': jax.numpy.bfloat16,
    'param_dtype': jax.numpy.bfloat16,
    'input_shape': (1, 1)
}

train_args = TrainArguments(
    model_class=type(model),
    configs_to_init_model_class=configs_to_init_model_class,
    custom_rule=config.get_partition_rules(True),
    model_name='EasyDelLLama2',
    num_train_epochs=1,
    learning_rate=4e-05,
    learning_rate_end=1e-05,
    warmup_steps=78,
    optimizer='adamw',
    scheduler='warm_up_linear',
    weight_decay=0.01,
    total_batch_size=32,
    max_steps=None,
    do_train=True,
    do_eval=False,
    backend='tpu',
    max_length=max_length,
    gradient_checkpointing='nothing_saveable',
    sharding_array=(1, -1, 1),
    use_pjit_attention_force=False,
    gradient_accumulation_steps=1,
    remove_ckpt_after_load=True,
    ids_to_pop_from_dataset=['token_type_ids'],
    loss_remat='',
    is_left_padded=True,
    dtype=jax.numpy.bfloat16
)

trainer = CausalLMTrainer(
    train_args,
    dataset_train,
    ckpt_path=None
)

output = trainer.train(flax.core.FrozenDict({'params': params}))

saved_model_location = f"{str(train_args.get_path())}/{output.last_save_file_name}"

print("Hey im Here in case you want to load me :", saved_model_location)

### Let's convert the model to HF/PyTorch

from EasyDel.transform import llama_easydel_to_hf

config.rope_theta = 10000
config.attention_bias = False
model = llama_easydel_to_hf(saved_model_location, config=config)

# Here's your Huggingface Torch Llama
model = model.half()

# Save model and tokenizer
new_model = "/kaggle/working/Sheared-LLaMA-2.7B-Synthia"
model.save_pretrained(new_model)
tokenizer.save_pretrained(new_model)
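
To see what `formatting_prompts_func` actually produces, here is a standalone re-implementation of the same template, assuming the Llama EOS string `</s>`. `render` and the sample record are hypothetical. For a record with an empty `input` field, the text starts with a bare `SYSTEM: ` line followed by a blank line.

```python
# Same template as formatting_prompts_func above, written without the walrus operator.
ROLE_PREFIX = {"system": "SYSTEM:", "human": "USER:", "gpt": "ASSISTANT:"}


def render(convo, eos_token="</s>"):
    """Render one ShareGPT-style conversation into a single training string."""
    role_suffix = {"system": "\n\n", "human": "\n", "gpt": f"{eos_token}\n"}
    return "".join(
        f"{ROLE_PREFIX[turn['from']]} {turn['value']}{role_suffix[turn['from']]}"
        for turn in convo
    )


# Hypothetical Open-Platypus-style record with an empty `input` field.
example = {"input": "", "instruction": "What is 2 + 2?", "output": "4"}
convo = [
    {"from": "system", "value": example["input"]},
    {"from": "human", "value": example["instruction"]},
    {"from": "gpt", "value": example["output"]},
]
print(repr(render(convo)))
# 'SYSTEM: \n\nUSER: What is 2 + 2?\nASSISTANT: 4</s>\n'
```

When `input` is non-empty, the first line becomes `SYSTEM: <text>` instead, so records with and without a system prompt differ right from their first tokens.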

erfanzar commented 6 months ago

Thank you for providing your training code. I'll try training your model to debug this. I didn't know about the 2.7B pre-trained model; that's something cool.

I also think you should change the optimizer and its related parameters.

For example, use warm_up_cosine with a learning rate of 9e-5 and train for 4 epochs.
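To compare the original schedule with the suggested one, here is a pure-Python sketch of generic warm-up + linear decay and warm-up + cosine decay curves. These are textbook formulas assumed for illustration, not taken from EasyDeL; its actual `warm_up_linear` and `warm_up_cosine` schedulers may differ in details such as the warm-up start value or the final learning rate. `steps_per_epoch = 780` is a hypothetical placeholder, and reusing 78 warm-up steps and an end value of 0.0 for the cosine run are also assumptions.

```python
import math


def warmup_linear(step: int, peak: float, end: float, warmup: int, total: int) -> float:
    """Linear warm-up from 0 to `peak`, then linear decay from `peak` to `end`."""
    if step < warmup:
        return peak * step / warmup
    frac = min((step - warmup) / max(1, total - warmup), 1.0)
    return peak + (end - peak) * frac


def warmup_cosine(step: int, peak: float, end: float, warmup: int, total: int) -> float:
    """Linear warm-up from 0 to `peak`, then half-cosine decay from `peak` to `end`."""
    if step < warmup:
        return peak * step / warmup
    frac = min((step - warmup) / max(1, total - warmup), 1.0)
    return end + 0.5 * (peak - end) * (1.0 + math.cos(math.pi * frac))


steps_per_epoch = 780  # hypothetical; roughly len(dataset_train) // total_batch_size

# Original run: 1 epoch, peak 4e-5 decaying linearly to 1e-5 after 78 warm-up steps.
old = [warmup_linear(s, 4e-5, 1e-5, 78, steps_per_epoch) for s in range(steps_per_epoch)]
# Suggested run: 4 epochs, peak 9e-5 with cosine decay (end value 0.0 is an assumption).
total = 4 * steps_per_epoch
new = [warmup_cosine(s, 9e-5, 0.0, 78, total) for s in range(total)]
print(f"old: peak {max(old):.1e}, last {old[-1]:.1e}")
print(f"new: peak {max(new):.1e}, last {new[-1]:.1e}")
```

The cosine curve stays near its peak longer and then decays smoothly, and spreading it over 4 epochs means every example is seen several times at different learning rates.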

At first, I thought you were working on the beta branch (where 4D mesh support causes wrong predictions in some cases), but it seems you are on the main branch, where the code is stable enough!

erfanzar commented 6 months ago

This issue is caused by your dataset and system prompt. You are training the model for only 1 epoch, and the first half of the data has no system prompt while the second half most likely does. Try shuffling the dataset or training the model for more epochs.
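A quick way to check for this kind of ordering effect is to measure, chunk by chunk, how many examples carry a non-empty system prompt. The sketch below uses toy data (a hypothetical 50/50 split, not real Open-Platypus statistics) and a hypothetical `nonempty_fraction_per_chunk` helper; on the real data you would pass the raw `input` column of the Open-Platypus train split. In Hugging Face `datasets`, the shuffle itself would be `dataset.shuffle(seed=42)` applied before building the trainer.

```python
import random
from typing import List


def nonempty_fraction_per_chunk(system_prompts: List[str], chunk_size: int) -> List[float]:
    """Share of examples with a non-empty system prompt in each contiguous chunk."""
    fractions = []
    for start in range(0, len(system_prompts), chunk_size):
        chunk = system_prompts[start:start + chunk_size]
        fractions.append(sum(1 for s in chunk if s.strip()) / len(chunk))
    return fractions


# Toy data (hypothetical): no system prompt in the first half, a system prompt in the second half.
system_prompts = [""] * 500 + ["Answer the question step by step."] * 500
print(nonempty_fraction_per_chunk(system_prompts, 250))  # [0.0, 0.0, 1.0, 1.0]

shuffled = system_prompts[:]
random.Random(42).shuffle(shuffled)  # the datasets equivalent is dataset.shuffle(seed=42)
print(nonempty_fraction_per_chunk(shuffled, 250))  # every chunk now holds a mix of both kinds
```

With the unshuffled order, the model sees one distribution for half the epoch and a different one afterwards, which can show up as a sudden loss jump; after shuffling, each chunk (and so each stretch of batches) has a similar mix.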

infamix commented 6 months ago

Oh, thanks for the reply. Will try now!

infamix commented 6 months ago

That was precisely the issue; once again, I should have been more attentive!