erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options built on JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

ValueError: Dict key mismatch; expected keys: ['transformer']; #92

Closed: jchauhan closed this issue 5 months ago

jchauhan commented 5 months ago

I am getting the following error while fine-tuning the gpt2 model.

To Reproduce

Time Took to Complete Task configure dataloaders (microseconds) : 0.33593177795410156
Time Took to Complete Task configure Model ,Optimizer ,Scheduler and Config (microseconds) : 677.0551204681396
Time Took to Complete Task configure functions and sharding them (microseconds) : 790.2214527130127
Action : Sharding Passed Parameters
Traceback (most recent call last):
  File "/home/xxx/research/transformers/train_gpt_easydel.py", line 92, in <module>
    output = trainer.train(flax.core.FrozenDict({"params": params}))
  File "/home/xxx/research/EasyDeL/lib/python/EasyDel/trainer/causal_language_model_trainer.py", line 669, in train
    sharded_state, shard_fns, gather_fns = self.init_state(
  File "/home/xxx/research/EasyDeL/lib/python/EasyDel/trainer/causal_language_model_trainer.py", line 585, in init_state
    params = model_parameters if not self.arguments.do_shard_fns else jax.tree_util.tree_map(
  File "/home/xxx/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/tree_util.py", line 243, in tree_map
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
  File "/home/xxx/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/tree_util.py", line 243, in <listcomp>
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
ValueError: Dict key mismatch; expected keys: ['transformer']; dict: {'transformer': {'wte': {'embedding': array([[-0.11010301, -0.03926672,  0.03310751, ..., -0.1363697 ,
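
For context, the exception comes from jax.tree_util.tree_map, which flattens its second pytree against the structure of the first; whenever the dict keys at some nesting level differ between the two trees (for example an extra wrapper level, or a FrozenDict versus plain-dict layout), it raises exactly this kind of "Dict key mismatch" ValueError. Below is a minimal sketch with made-up toy names, not EasyDeL's actual shard_fns or parameters, that reproduces the mechanism:

import jax

# structure the mapping functions expect: a single top-level key 'transformer'
expected_fns = {"transformer": {"wte": {"embedding": lambda x: x}}}
# parameter tree with an extra nesting level, so the top-level keys differ
mismatched_params = {"params": {"transformer": {"wte": {"embedding": 0.0}}}}

try:
    # tree_map flattens mismatched_params against the treedef of expected_fns
    jax.tree_util.tree_map(lambda fn, p: fn(p), expected_fns, mismatched_params)
except ValueError as err:
    print(err)  # a dict key mismatch error analogous to the traceback above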

Example Code

import jax.numpy
from EasyDel import (
    TrainArguments,
    CausalLanguageModelTrainer,
    AutoEasyDelModelForCausalLM,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer

huggingface_repo_id_or_path = "gpt2"
max_length = 512
trained_model_name = "changeme"
easydel_trained_model_name = f"{trained_model_name}.easydel"
training_data_files = "changeme.json"

model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path)

tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_init_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, 1)
}

train_arguments = TrainArguments(
    model_class=type(model),
    model_name=easydel_trained_model_name,
    num_train_epochs=3,
    configs_to_init_model_class=configs_to_init_model_class,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDelOptimizers.ADAMW,  # "adamw", "lion", "adafactor" are supported
    scheduler=EasyDelSchedulers.LINEAR,
    # "linear","cosine", "none" ,"warm_up_cosine" and "warm_up_linear"  are supported
    weight_decay=0.01,
    total_batch_size=1,
    max_steps=None,  # None to let the trainer decide
    do_train=True,
    do_eval=False,  # it's optional but supported 
    backend="tpu",  # default backed is set to cpu, so you must define you want to use tpu cpu or gpu
    max_length=max_length,  # Note that you have to change this in the model config too
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),  # how the model is sharded across GPUs, CPUs or TPUs; with (1, -1, 1, 1)
    # training runs in fully automatic FSDP mode and data is shared between devices
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=8,
    loss_re_mat="",
    dtype=jnp.bfloat16
)

def ultra_chat_prompting_process(
        data_chunk
):
    return {"prompt": data_chunk['train']}

tokenization_process = lambda data_chunk: tokenizer(
    data_chunk["prompt"],
    add_special_tokens=False,
    max_length=max_length,
    padding="max_length",
    truncation=True  # truncate to max_length so every example has the same length
)

dataset = load_dataset("json", data_files=training_data_files)
dataset_train = dataset["train"].map(ultra_chat_prompting_process, num_proc=12)
dataset_train = dataset_train.map(
    tokenization_process,
    num_proc=12,
    remove_columns=dataset_train.column_names
)

# you can do the same for the evaluation dataset

trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset_train,
    checkpoint_path=None
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey ! , here's where your model saved {output.checkpoint_path}")
erfanzar commented 5 months ago

I ran the same code and training started just fine.

erfanzar commented 5 months ago

This issue is being closed because no response has been given.