Closed infamix closed 6 months ago
Similar loss graph with the same dataset but slightly different training hyperparameters.
I have some questions:
The issue you are facing right now is not caused by EasyDel; it is most likely caused by your training hyperparameters.
All the models in the graphs above are Llama2-7B except clean-sound-25 (Llama2-13B).
I have trained more than 30 models and I haven't seen any issue like that.
If both the model and the dataset you are using are open source, you can share them with me so I can check and train them for you.
Training dataset is Open-Platypus (instruct fine-tuning), I am using Kaggle's TPU v3-8.
Are you using 8-bit or custom bits to train this model?
I ask because I don't think there is any other way to train a 7B model on a device with 128GB of VRAM with total_batch_size=32.
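The memory concern behind this question can be made concrete with a rough back-of-the-envelope estimate. The sketch below is only an approximation under stated assumptions: parameters and gradients in bfloat16 (2 bytes each), two AdamW moment buffers per parameter, and no accounting for activations, which grow with total_batch_size and sequence length. The function name `training_state_gb` is hypothetical and is not part of EasyDel.

```python
def training_state_gb(n_params, bytes_param=2, bytes_grad=2,
                      bytes_opt_state=2, n_opt_states=2):
    """Rough size in GB of params + grads + optimizer states (activations excluded)."""
    bytes_per_param = bytes_param + bytes_grad + n_opt_states * bytes_opt_state
    return n_params * bytes_per_param / 1e9

# Assumption: everything, including the AdamW moments, is kept in bfloat16.
state_7b_bf16 = training_state_gb(7e9)
state_2_7b_bf16 = training_state_gb(2.7e9)

# Assumption: AdamW moments kept in float32 (4 bytes each) instead.
state_7b_fp32_opt = training_state_gb(7e9, bytes_opt_state=4)

print(f"7B,   bf16 optimizer states: {state_7b_bf16:.1f} GB")
print(f"7B,   fp32 optimizer states: {state_7b_fp32_opt:.1f} GB")
print(f"2.7B, bf16 optimizer states: {state_2_7b_bf16:.1f} GB")
```

Under these assumptions a 7B model already uses a large share of the 128GB before any activations are stored, while a 2.7B model leaves far more headroom.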
Hi, it's the 2.7B Llama from Princeton-NLP. No custom bits either.
Let me show you my full training code:
from EasyDel import TrainArguments, CausalLMTrainer, AutoEasyDelModelForCausalLM
from transformers import AutoTokenizer
import jax
import flax
from datasets import load_dataset
model_id = 'bn22/Sheared-LLaMA-2.7B-Sharded'
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    padding_side="left",
    add_eos_token=True,
    add_bos_token=True,
)
tokenizer.pad_token = tokenizer.eos_token
sequence_max_length = 2048
#@title Synthia dataset preparation
from datasets import load_dataset
dataset = load_dataset("garage-bAInd/Open-Platypus", split = "train")
def prompt_formatting_alpaca_to_sharegpt(example):
    return {"conversations": [
        {"from": "system", "value": example['input']},
        {"from": "human", "value": example['instruction']},
        {"from": "gpt", "value": example['output']},
    ]}
dataset = dataset.map(prompt_formatting_alpaca_to_sharegpt, num_proc=4)
def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = []
    mapper = {"system": "SYSTEM:", "human": "USER:", "gpt": "ASSISTANT:"}
    end_mapper = {"system": "\n\n", "human": "\n", "gpt": f"{tokenizer.eos_token}\n"}
    for convo in convos:
        text = "".join(f"{mapper[(turn := x['from'])]} {x['value']}{end_mapper[turn]}" for x in convo)
        texts.append(text)
    return {"text": texts}
columns_to_remove = dataset.column_names
dataset = dataset.map(formatting_prompts_func, batched = True, remove_columns=columns_to_remove)
dataset_train_raw = dataset
def tokenize_function(examples):
    # truncation=True added: with padding='max_length' alone, texts longer than
    # max_length are not truncated and would produce over-length rows
    return tokenizer(examples["text"], padding='max_length', truncation=True, max_length=sequence_max_length)
dataset_train = dataset_train_raw.map(tokenize_function, batched=True, num_proc=4, remove_columns=dataset_train_raw.column_names)
# The dataset should only contain numerical fields for the model, such as input_ids, attention_mask, ...
model, params = AutoEasyDelModelForCausalLM.from_pretrained(
    model_id,
    dtype=jax.numpy.bfloat16,
    param_dtype=jax.numpy.bfloat16,
    precision=jax.lax.Precision('fastest'),
    device=jax.devices('cpu')[0],  # Load the JAX model and initialize or load parameters on CPU
    # The rest of the kwargs are passed to Hugging Face AutoModelForCausalLM, such as this device_map
    device_map='auto'
)
config = model.config
# this part of code is only for making model faster and more optimized
config.freq_max_position_embeddings = config.max_position_embeddings
config.max_position_embeddings = 4096
config.c_max_position_embeddings = config.max_position_embeddings
config.use_pjit_attention_force = False # disable pjit attention force is recommended in case of using MP = 1 in sharding Mesh
max_length = sequence_max_length
configs_to_init_model_class = {
    'config': config,
    'dtype': jax.numpy.bfloat16,
    'param_dtype': jax.numpy.bfloat16,
    'input_shape': (1, 1)
}
train_args = TrainArguments(
    model_class=type(model),
    configs_to_init_model_class=configs_to_init_model_class,
    custom_rule=config.get_partition_rules(True),
    model_name='EasyDelLLama2',
    num_train_epochs=1,
    learning_rate=4e-05,
    learning_rate_end=1e-05,
    warmup_steps=78,
    optimizer='adamw',
    scheduler='warm_up_linear',
    weight_decay=0.01,
    total_batch_size=32,
    max_steps=None,
    do_train=True,
    do_eval=False,
    backend='tpu',
    max_length=max_length,
    gradient_checkpointing='nothing_saveable',
    sharding_array=(1, -1, 1),
    use_pjit_attention_force=False,
    gradient_accumulation_steps=1,
    remove_ckpt_after_load=True,
    ids_to_pop_from_dataset=['token_type_ids'],
    loss_remat='',
    is_left_padded=True,
    dtype=jax.numpy.bfloat16
)
trainer = CausalLMTrainer(
    train_args,
    dataset_train,
    ckpt_path=None
)
output = trainer.train(flax.core.FrozenDict({'params': params}))
saved_model_location = f"{str(train_args.get_path())}/{output.last_save_file_name}"
print("Hey im Here in case you want to load me :", saved_model_location)
### Convert the model to HF/PyTorch
from EasyDel.transform import llama_easydel_to_hf
config.rope_theta = 10000
config.attention_bias = False
model = llama_easydel_to_hf(saved_model_location, config=config)
# Here's your Huggingface Torch Llama
model = model.half()
# Save model and tokenizer
new_model = "/kaggle/working/Sheared-LLaMA-2.7B-Synthia"
model.save_pretrained(new_model)
tokenizer.save_pretrained(new_model)
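To make the prompt template in the code above concrete, here is a minimal standalone sketch that renders one Alpaca-style row with a non-empty `input` and one with an empty `input`. It uses no tokenizer: the `EOS` string `"</s>"` is an assumption standing in for `tokenizer.eos_token`, and the two example rows are made up for illustration.

```python
EOS = "</s>"  # assumption: stands in for tokenizer.eos_token

MAPPER = {"system": "SYSTEM:", "human": "USER:", "gpt": "ASSISTANT:"}
END_MAPPER = {"system": "\n\n", "human": "\n", "gpt": f"{EOS}\n"}

def to_conversation(example):
    """Map an Alpaca-style row to the ShareGPT-style turn list used above."""
    return [
        {"from": "system", "value": example["input"]},
        {"from": "human", "value": example["instruction"]},
        {"from": "gpt", "value": example["output"]},
    ]

def render(convo):
    """Join the turns into one training string with role prefixes and separators."""
    return "".join(
        f"{MAPPER[turn['from']]} {turn['value']}{END_MAPPER[turn['from']]}"
        for turn in convo
    )

# Hypothetical rows: one with an `input` field, one without.
with_input = {"input": "Answer briefly.", "instruction": "What is 2 + 2?", "output": "4"}
without_input = {"input": "", "instruction": "What is 2 + 2?", "output": "4"}

text_with = render(to_conversation(with_input))
text_without = render(to_conversation(without_input))

print(repr(text_with))
print(repr(text_without))
```

Rows with an empty `input` still get a `SYSTEM:` prefix followed by nothing, so the two kinds of rows look noticeably different to the model.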
Thank you for providing your training code. I'll try training your model to debug this. I didn't know about the 2.7B pre-trained model; that's something cool.
I'd also suggest changing the optimizer settings and the parameters related to them,
for example using warm_up_cosine with a learning rate of 9e-5 and training for 4 epochs.
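For readers unfamiliar with the suggested schedule, the sketch below shows a generic linear warm-up followed by cosine decay. It is a plain-Python illustration of the general technique, not EasyDel's own implementation; the function name `warmup_cosine_lr` and the value `total_steps=3000` are assumptions made for the example, while the warm-up length (78) and end learning rate (1e-5) are taken from the training arguments above.

```python
import math

def warmup_cosine_lr(step, peak_lr, end_lr, warmup_steps, total_steps):
    """Linear warm-up from 0 to peak_lr, then cosine decay from peak_lr to end_lr."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return end_lr + 0.5 * (peak_lr - end_lr) * (1.0 + math.cos(math.pi * progress))

# Assumption: total_steps is a made-up value for illustration only.
PEAK_LR, END_LR, WARMUP, TOTAL = 9e-5, 1e-5, 78, 3000

lr_start = warmup_cosine_lr(0, PEAK_LR, END_LR, WARMUP, TOTAL)
lr_peak = warmup_cosine_lr(WARMUP, PEAK_LR, END_LR, WARMUP, TOTAL)
lr_final = warmup_cosine_lr(TOTAL, PEAK_LR, END_LR, WARMUP, TOTAL)

for step in (0, 39, 78, 1500, 3000):
    print(step, warmup_cosine_lr(step, PEAK_LR, END_LR, WARMUP, TOTAL))
```

Compared with a linear decay, the cosine shape keeps the learning rate near its peak for longer early on and flattens out toward the end of training.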
First of all, I thought you were working on the beta branch (4D mesh support causes wrong predictions in some cases), but it seems you are on the main branch, where the code is stable enough.
This issue is caused by your dataset and system prompt: you are training the model for only 1 epoch, and the first half of the data has no system prompt while the second half most likely does. Try shuffling the dataset or training the model for more epochs.
Oh, thanks for the reply. Will try now!
That was precisely the issue, once again I should have been more attentive!
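The effect of the fix can be illustrated with a small standalone sketch. It builds a hypothetical toy dataset in which only the second half of the rows has a system prompt, then compares the share of such rows per chunk of training order before and after shuffling. The toy rows and the helper name `system_fraction_per_chunk` are made up for illustration; in the real pipeline the equivalent step would be a call such as `dataset = dataset.shuffle(seed=42)` from the `datasets` library before tokenization.

```python
import random

# Hypothetical toy dataset: the first 50 rows have no system prompt, the last 50 do.
rows = [{"id": i, "has_system": i >= 50} for i in range(100)]

def system_fraction_per_chunk(data, chunk_size=25):
    """Fraction of rows with a system prompt in each consecutive chunk."""
    return [
        sum(r["has_system"] for r in data[i:i + chunk_size]) / chunk_size
        for i in range(0, len(data), chunk_size)
    ]

ordered = system_fraction_per_chunk(rows)

shuffled_rows = rows[:]                  # copy so the original order is kept
random.Random(0).shuffle(shuffled_rows)  # fixed seed for reproducibility
shuffled = system_fraction_per_chunk(shuffled_rows)

print("ordered: ", ordered)   # [0.0, 0.0, 1.0, 1.0]
print("shuffled:", shuffled)
```

In the ordered case the model sees an abrupt change in data distribution halfway through the single epoch, which is consistent with a sudden change in the loss curve; after shuffling, each chunk contains a mix of both kinds of rows.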
Hello once again, I am seeing some weird behavior with my loss whenever I use EasyDel for fine-tuning, no matter the dataset.
![image](https://github.com/erfanzar/EasyDeL/assets/34374618/038582dc-31e7-4767-b416-79d7d4116236)
These are my training args:
Training dataset is Open-Platypus (left-padded) and the model is Sheared-Llama-2.7B.