erfanzar / EasyDeL

Accelerate, Optimize performance with streamlined training and serving options with JAX.
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

training does not start using latest easydel #146

Closed IvoryTower800 closed 5 months ago

IvoryTower800 commented 5 months ago

Hi, the training doesn't start with the recent update. I tried different models and parameters. It only shows the information below, then it stops running.

Besides, I need to manually set the parameters below to load the model; otherwise, there is a ValueError.

model, params = AutoEasyDelModelForCausalLM.from_pretrained(
    huggingface_repo_id_or_path,
    auto_shard_params=True,
    sharding_axis_dims=(1, -1, 1, 1),
    input_shape=(8, max_length),
)
Warning : In case of using `finetune = True` and Passing `checkpoint_path = None` you should pass parameters in train function
wandb: Currently logged in as: ivorytower800. Use `wandb login --relogin` to force relogin
Tracking run with wandb version 0.16.6
Run data is saved locally in /kaggle/working/wandb/run-20240424_161106-rims0awm
Syncing run [woven-snowflake-158](https://wandb.ai/ivorytower800/EasyDeL-writer-gemma-2b/runs/rims0awm) to [Weights & Biases](https://wandb.ai/ivorytower800/EasyDeL-writer-gemma-2b) ([docs](https://wandb.me/run))
View project at https://wandb.ai/ivorytower800/EasyDeL-writer-gemma-2b
View run at https://wandb.ai/ivorytower800/EasyDeL-writer-gemma-2b/runs/rims0awm
erfanzar commented 5 months ago

Hi, are you sure you are calling trainer.train()?

erfanzar commented 5 months ago

Can you share the code?

IvoryTower800 commented 5 months ago

@erfanzar Sure, below is the code.

from EasyDel import (
    TrainArguments,
    AutoEasyDelModelForCausalLM,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers,
    SFTTrainer,
    CausalLanguageModelTrainer,
    conversations_formatting_function
)
from datasets import load_dataset,load_from_disk
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer

max_length = 8192

huggingface_repo_id_or_path = "google/gemma-2b-it"

model, params = AutoEasyDelModelForCausalLM.from_pretrained(
    huggingface_repo_id_or_path,
    auto_shard_params=True,
    sharding_axis_dims=(1, -1, 1, 1),
    input_shape=(8, max_length),
)
# model.config.add_basic_configurations(
#     attn_mechanism="wise_ring",  # Using Flash Attention here you can simply just set this to normal or ring
# )

tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (8, max_length)
}

train_arguments = TrainArguments(
    model_class=type(model),
    model_name="writer-gemma-2b",
    num_train_epochs=1,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    custom_rule=model.config.get_partition_rules(True),
    learning_rate=0.000001846,
    learning_rate_end=2e-7,
    max_sequence_length=max_length,
    optimizer=EasyDelOptimizers.ADAMW,
    scheduler=EasyDelSchedulers.LINEAR,
    weight_decay=0.01,
    warmup_steps=0,
    total_batch_size=8,
    save_optimizer_state=False,
    max_training_steps=None,
    do_train=True,
    do_eval=False,
    backend="tpu",
    max_length=max_length,
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    init_input_shape=(8, max_length),
    gradient_accumulation_steps=1,
    loss_re_mat="",
    dtype=jnp.bfloat16,
    training_time="8H",
    track_memory=True,
    force_batch_and_gradient_accumulation_steps_calculation=True,
    use_wandb=True,  # set to False to disable WandB logging
)

dataset_train = load_from_disk('/kaggle/input/sadlfkjaslkgma8192')
desired_indices = range(0, len(dataset_train))
dataset_train = dataset_train.select(desired_indices)

trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset_train,
#     checkpoint_path='/root/' + ckpt_name
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey ! , here's where your model saved {output.checkpoint_path}")
erfanzar commented 5 months ago

If you are using auto_shard_params=True when loading the model, you should disable do_shard_fns in TrainArguments.
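
A minimal sketch of that change, reusing the names from the snippet above; the do_shard_fns keyword is taken from this comment and is assumed here rather than verified against a specific EasyDeL version, and the elided arguments are the ones already shown in the original code:

# Sketch only: the parameters are already sharded at load time by
# auto_shard_params=True, so tell the trainer not to re-shard them.
train_arguments = TrainArguments(
    model_class=type(model),
    model_name="writer-gemma-2b",
    num_train_epochs=1,
    max_sequence_length=max_length,
    total_batch_size=8,
    do_shard_fns=False,  # skip re-sharding inside the trainer (assumed flag name)
    # ... keep the remaining arguments from the snippet above unchanged
)

trainer = CausalLanguageModelTrainer(train_arguments, dataset_train)
output = trainer.train(flax.core.FrozenDict({"params": params}))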

erfanzar commented 5 months ago

@IvoryTower800 is that fixed?

IvoryTower800 commented 5 months ago

@erfanzar Hi, sorry for the late reply. It was fixed. Thank you!