erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options with JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

OOM during llama2-7b SFT #163

Open · kuangdao opened this issue 1 week ago

kuangdao commented 1 week ago

I'm trying to SFT llama2-7b and hitting OOM. Does EasyDeL support FSDP or tensor parallelism?

kuangdao commented 1 week ago

Can anyone tell me why?

The error is:

[Screenshot 2024-06-20 14:55:08 showing the OOM traceback]

and the code is:

from easydel import (
    TrainArguments,
    AutoEasyDeLModelForCausalLM,
    EasyDeLOptimizers,
    EasyDeLSchedulers,
    EasyDeLGradientCheckPointers,
    SFTTrainer,
    conversations_formatting_function  # added for newcomers: a pre-created prompter
    # they can use if they don't want to write their own formatting function
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer

huggingface_repo_id_or_path = "/cfs/models/Llama2-Chinese-7b-Chat"
huggingface_repo_id_or_path = "TinyLlama-1.1B-intermediate-step-1431k-3T"
model, params = AutoEasyDeLModelForCausalLM.from_pretrained(huggingface_repo_id_or_path)

max_length = 4096
tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, 1)
}

sharding_axis_dims = (1, -1, 1, 1)

train_arguments = TrainArguments(
    model_class=type(model),
    model_name="SFT-EasyDeL",
    num_train_epochs=3,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDeLOptimizers.ADAMW,
    scheduler=EasyDeLSchedulers.WARM_UP_COSINE,
    weight_decay=0.01,
    total_batch_size=1,
    max_training_steps=None,  # None lets the trainer decide
    do_train=True,
    fully_sharded_data_parallel=True,
    sharding_array=sharding_axis_dims,  # how the model is sharded across GPUs, CPUs, or TPUs
    # step_partition_spec='fsdp',
    do_eval=False,  # it's optional but supported
    backend="gpu",  # the default backend is cpu, so explicitly choose tpu, cpu, or gpu
    max_sequence_length=max_length,  # note that you have to change this in the model config too
    gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=8,
    loss_re_mat="",
    dtype=jnp.bfloat16
)

def prompter(sample):
    return [conversations_formatting_function(tokenizer, messages_field="messages")(sample)]

train_dataset = load_dataset("HuggingFaceH4/deita-10k-v0-sft", split="train_sft")

data_dict = {'train': './ultrachat_200k/ultrachat_200k-train_sft-00000-of-00003.arrow'}

train_dataset = load_dataset('arrow', data_files=data_dict)['train']

trainer = SFTTrainer(
    arguments=train_arguments,
    train_dataset=train_dataset,
    eval_dataset=None,  # we don't have an eval dataset rn :)
    tokenizer=tokenizer,
    dataset_text_field=None,
    formatting_func=prompter,
    packing=True,
    num_of_sequences=max_length,
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey ! , here's where your model saved {output.checkpoint_path}")

I run the script with: python train.py

kuangdao commented 1 week ago

I'm running it on A800 GPUs: 1 node with 8 GPUs.
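For reference, a quick sanity check (not part of the original script; a minimal sketch assuming JAX is installed with GPU support) to confirm that JAX actually sees all 8 GPUs on the node before training:

import jax

# Minimal device-visibility check; expect backend "gpu" and 8 devices on a single A800 node.
print("backend:", jax.default_backend())
print("devices:", jax.devices())
print("device count:", jax.device_count())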

erfanzar commented 1 week ago

Hello @kuangdao, thanks for using EasyDeL, and sorry for the late response. You can try the code below; it uses FSDP, but you can also switch to sequence parallelization (a small sketch of that change follows the code).


from easydel import (
    TrainArguments,
    AutoEasyDeLModelForCausalLM,
    EasyDeLOptimizers,
    EasyDeLSchedulers,
    EasyDeLGradientCheckPointers,
    SFTTrainer,
    PartitionAxis,
    conversations_formatting_function  # added for newcomers: a pre-created prompter
    # they can use if they don't want to write their own formatting function
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer

# huggingface_repo_id_or_path = "/cfs/models/Llama2-Chinese-7b-Chat"
dtype = jnp.bfloat16
block_size = 512
attn_mechanism = "sharded_vanilla"
partition_axis = PartitionAxis()
huggingface_repo_id_or_path = "TinyLlama-1.1B-intermediate-step-1431k-3T"
sharding_axis_dims = (1, -1, 1, 1)  # Change to 1,1,1,-1 for Sequence Sharding
max_length = 4096
model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    huggingface_repo_id_or_path,
    param_dtype=dtype,
    dtype=dtype,
    input_shape=(8, 8),  # since you said you have 8 GPUs
    auto_shard_params=True,
    sharding_axis_dims=sharding_axis_dims,
    verbose_params=True,
    config_kwargs=dict(
        use_scan_mlp=False,
        attn_mechanism=attn_mechanism,
        partition_axis=partition_axis
    ),
    partition_axis=partition_axis,
)

tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": dtype,
    "param_dtype": dtype,
    "input_shape": (8, 8)
}

# sharding_axis_dims = (1, -1, 1, 1)

train_arguments = TrainArguments(
    model_class=type(model),
    model_name="SFT-EasyDeL",
    num_train_epochs=3,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDeLOptimizers.ADAMW,
    scheduler=EasyDeLSchedulers.WARM_UP_COSINE,
    weight_decay=0.01,
    total_batch_size=8,  # note: with FSDP across 8 GPUs you can't use a batch size of 1
    max_training_steps=None,  # None lets the trainer decide
    do_train=True,
    fully_sharded_data_parallel=True,
    force_batch_and_gradient_accumulation_steps_calculation=False,
    do_eval=False,  # it's optional but supported
    backend="gpu",  # default backed is set to cpu, so you must define you want to use tpu cpu or gpu
    max_sequence_length=max_length,  # Note that you have to change this in the model config too
    gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=sharding_axis_dims,
    # how the model is sharded across GPUs, CPUs, or TPUs via the sharding array (1, -1, 1, 1);
    # training is split automatically (sequence/model parallel) and data is shared between devices
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=8,
    loss_re_mat="",
    use_pjit_attention_force=False,
    init_input_shape=(8, 8),
    dtype=dtype,
    param_dtype=dtype,
    step_start_point=0,
    do_last_save=False,
    do_shard_fns=False,
    track_memory=False,  # requires Go to be installed first
)

def prompter(sample):
    return [conversations_formatting_function(tokenizer, messages_field="messages")(sample)]

# train_dataset = load_dataset("HuggingFaceH4/deita-10k-v0-sft", split="train_sft")

data_dict = {'train': './ultrachat_200k/ultrachat_200k-train_sft-00000-of-00003.arrow'}

train_dataset = load_dataset('arrow', data_files=data_dict)['train']

trainer = SFTTrainer(
    arguments=train_arguments,
    train_dataset=train_dataset,
    eval_dataset=None,  # we don't have eval dataset rn :)
    tokenizer=tokenizer,
    dataset_text_field=None,
    formatting_func=prompter,
    packing=True,
    num_of_sequences=max_length,
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey ! , here's where your model saved {output.checkpoint_path}")
erfanzar commented 1 week ago

Is that fixed, @kuangdao?