erfanzar / EasyDeL

Accelerate, Optimize performance with streamlined training and serving options with JAX.
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

Error while training GPT2 on Kaggle #98

Closed jchauhan closed 9 months ago

jchauhan commented 9 months ago

Describe the bug

Error while training GPT2 on Kaggle:

/root
Downloading base model...
<class 'EasyDel.modules.gpt2.modelling_gpt2_flax.FlaxGPT2LMHeadModel'>
<class 'EasyDel.modules.gpt2.gpt2_configuration.GPT2Config'>
Downloading data files: 100%|██████████████████| 1/1 [00:00<00:00, 11214.72it/s]
Extracting data files: 100%|████████████████████| 1/1 [00:00<00:00, 1438.38it/s]
Generating train split: 186074 examples [00:00, 353355.15 examples/s]
Map (num_proc=12): 100%|██████| 186074/186074 [00:03<00:00, 47310.49 examples/s]
/usr/local/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by promote_options='default'.
  table = cls._concat_blocks(blocks, axis=0)
Map (num_proc=12):   0%|                      | 0/186074 [00:00<?, ? examples/s]
/usr/local/lib/python3.10/site-packages/datasets/table.py:1387: FutureWarning: promote has been superseded by promote_options='default'.
  return cls._concat_blocks(pa_tables_to_concat_vertically, axis=0)
Map (num_proc=12): 100%|███████| 186074/186074 [00:21<00:00, 8523.98 examples/s]
Warning : In case of using `finetune = True` and Passing `checkpoint_path = None` you should pass parameters in train function
wandb: Currently logged in as: jchauhan (safedep). Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.16.2
wandb: Run data is saved locally in /root/wandb/run-20240201_154611-g3pguwrw
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run avid-sky-14
wandb: ⭐️ View project at https://wandb.ai/safedep/EasyDeL-raven_gpt2.easydel
wandb: 🚀 View run at https://wandb.ai/safedep/EasyDeL-raven_gpt2.easydel/runs/g3pguwrw
Time Took to Complete Task configure dataloaders (microseconds) : 0.4191398620605469
Time Took to Complete Task configure Model ,Optimizer ,Scheduler and Config (microseconds) : 597.6324081420898
Time Took to Complete Task configure functions and sharding them (microseconds) : 745.0320720672607
Action : Sharding Passed Parameters
Traceback (most recent call last):
  File "/root/train.py", line 123, in <module>
    output = trainer.train(flax.core.FrozenDict({"params": params}))
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 478, in train
    sharded_state, shard_fns, gather_fns = self.initialize_state(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 393, in initialize_state
    params = model_parameters if not self.arguments.do_shard_fns else jax.tree_util.tree_map(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py", line 243, in tree_map
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
  File "/usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py", line 243, in <listcomp>
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
ValueError: Dict key mismatch; expected keys: ['transformer']; dict: {'transformer': {'wte': {'embedding':
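The traceback shows `jax.tree_util.tree_map` failing inside `initialize_state` while sharding the passed parameters: the trees being mapped together do not have the same keys at some nesting level. The dict in the message is cut off, so the exact difference is not visible here. A minimal way to locate such a mismatch is to flatten both nested dicts into leaf key paths and compare the two sets. The helpers `key_paths` and `diff_key_paths` below are hypothetical (not part of EasyDel, Flax, or JAX), and the toy trees are illustrative only. A Flax `FrozenDict` is a `Mapping`, so it should be accepted as-is.

```python
from collections.abc import Mapping
from typing import Any, Set, Tuple


def key_paths(tree: Mapping[str, Any], prefix: Tuple[str, ...] = ()) -> Set[Tuple[str, ...]]:
    """Return the key path (tuple of keys) of every leaf in a nested mapping."""
    paths: Set[Tuple[str, ...]] = set()
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, Mapping):
            paths |= key_paths(value, path)  # descend into sub-dicts
        else:
            paths.add(path)  # anything that is not a mapping is treated as a leaf
    return paths


def diff_key_paths(expected: Mapping[str, Any], actual: Mapping[str, Any]) -> dict:
    """Hypothetical debugging helper: report leaf paths that differ between two trees."""
    exp, act = key_paths(expected), key_paths(actual)
    return {
        "missing": sorted(exp - act),     # present in `expected` only
        "unexpected": sorted(act - exp),  # present in `actual` only
    }


if __name__ == "__main__":
    # Toy trees with GPT-2-style names; NOT taken from the log above.
    shard_fns = {"transformer": {"wte": {"embedding": "fn"}, "wpe": {"embedding": "fn"}}}
    params = {
        "transformer": {"wte": {"embedding": 0.0}, "wpe": {"embedding": 0.0}},
        "lm_head": {"kernel": 0.0},
    }
    print(diff_key_paths(shard_fns, params))
    # {'missing': [], 'unexpected': [('lm_head', 'kernel')]}
```

Any path reported as missing or unexpected is a leaf that `tree_map` cannot pair up, which is exactly the condition that raises "Dict key mismatch".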

To Reproduce

%%writefile /root/train.py

import os
import jax.numpy
import EasyDel
from EasyDel import (
    TrainArguments,
    CausalLanguageModelTrainer,
    AutoEasyDelModelForCausalLM,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer

from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
wand_key = user_secrets.get_secret("WAND_KEY")
# os.environ["WANDB_DISABLED"] = "false"
os.environ["WANDB_API_KEY"] = wand_key

base_model_hf_repo_id_or_path = "gpt2"
max_length = 1024
trained_model_name = "****"
trained_model_hf_repo_id = f"****/{trained_model_name}"
easydel_trained_model_name = f"{trained_model_name}.easydel"
training_data_files="****"

import json
jcdataset = load_dataset('****', split='train')
# Write one JSON line per conversation turn; the `with` block closes (and flushes)
# the file so its contents are complete before it is read back.
with open("./lmsys-toxic-gpt.json", "w") as f:
    for conversation in jcdataset['chunks']:
        out = "<|input|><|response|>"
        for req_res in conversation:
            out = out + req_res['prompt']
            f.write(json.dumps({'train': out}))
            f.write("\n")
            out = "<|input|>" + req_res['response'] + "<|response|>"

print("Downloading base model...")
model, params = AutoEasyDelModelForCausalLM.from_pretrained(base_model_hf_repo_id_or_path, trust_remote_code=True)

tokenizer = AutoTokenizer.from_pretrained(
    base_model_hf_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_init_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, 1)
}

model.config.use_scan_mlp = False

print(type(model))
print(type(model.config))

train_arguments = TrainArguments(
    model_class=type(model),
    model_name=easydel_trained_model_name,
    num_train_epochs=3,
    configs_to_initialize_model_class=configs_to_init_model_class,
    custom_rule=model.config.get_partition_rules(True),
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDelOptimizers.ADAMW,  # "adamw", "lion", "adafactor" are supported
    scheduler=EasyDelSchedulers.LINEAR,
    # "linear", "cosine", "none", "warm_up_cosine" and "warm_up_linear" are supported
    weight_decay=0.01,
    total_batch_size=8,
    max_steps=None,  # None lets the trainer decide
    do_train=True,
    do_eval=False,  # optional but supported
    backend="tpu",  # the default backend is cpu, so you must specify tpu, cpu or gpu
    max_length=max_length,  # note that you have to change this in the model config too
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),  # how the model is sharded across GPUs, CPUs or TPUs;
    # with (1, -1, 1, 1) training runs fully sharded (FSDP) and data is shared between devices automatically
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=8,
    loss_re_mat="",
    dtype=jnp.bfloat16
)

def ultra_chat_prompting_process(
        data_chunk
):
    return {"prompt": data_chunk['train']}

tokenization_process = lambda data_chunk: tokenizer(
    data_chunk["prompt"],
    add_special_tokens=False,
    max_length=max_length,
    padding="max_length",
    truncation=True  # without this, texts longer than max_length are not cut down
)

dataset = load_dataset("json", data_files=training_data_files)
dataset_train = dataset["train"].map(ultra_chat_prompting_process, num_proc=12)
dataset_train = dataset_train.map(
    tokenization_process,
    num_proc=12,
    remove_columns=dataset_train.column_names
)

# the same steps can be applied to an evaluation dataset

trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset_train,
    checkpoint_path=None
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey! Here's where your model was saved: {output.checkpoint_path}")

import tempfile
from huggingface_hub import Repository, create_repo
from transformers import GPT2LMHeadModel  # the base model is GPT-2, not Llama
import jax

from EasyDel import (
    AutoEasyDelConfig,
    EasyDelState,
    easystate_to_huggingface_model
)

# Function to create a Hugging Face repository
def create_hf_repo(repo_name, hub_token=None):
    tmp_dir = tempfile.TemporaryDirectory()
    tmp_output_dir = tmp_dir.name

    if repo_name is None:
        repo_name = os.path.basename(tmp_output_dir)

    # Create repo and retrieve repo_id
    repo_id = create_repo(repo_name, exist_ok=True, token=hub_token).repo_id

    # Clone repo locally
    repo = Repository(tmp_output_dir, clone_from=repo_id, token=hub_token)
    tmp_dir.cleanup()

    return repo

# Path of the checkpoint written by the trainer
chkpoint_path = output.checkpoint_path

# Load the configuration of the base model
config = AutoEasyDelConfig.from_pretrained(base_model_hf_repo_id_or_path)

# Convert the EasyDel training state into a Hugging Face (PyTorch) GPT-2 model on CPU
with jax.default_device(jax.devices("cpu")[0]):
    model = easystate_to_huggingface_model(
        state=EasyDelState.load_state(chkpoint_path),
        base_huggingface_module=GPT2LMHeadModel,
        config=config
    )

model = model.half()  # cast the weights to float16

# Check if the target Hugging Face repo exists, and create it if not
hub_token = None  # login is already done
# repo = create_hf_repo(trained_model_hf_repo_id, hub_token)

# Optionally, you can push the base model to the target repo as well
base_model = GPT2LMHeadModel.from_pretrained(base_model_hf_repo_id_or_path)
# base_model.push_to_hub(trained_model_hf_repo_id, token=hub_token)
tokenizer.push_to_hub(trained_model_hf_repo_id, token=hub_token)

# Push the custom model to the target Hugging Face repo
model.push_to_hub(trained_model_hf_repo_id, token=hub_token)
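The data-preparation loop at the top of the script builds one `{"train": ...}` JSON line per conversation turn. A minimal, testable sketch of that same step follows. The function names `conversation_to_records` and `write_jsonl` are hypothetical; the `prompt`/`response` field names and the `<|input|>`/`<|response|>` markers are copied from the script. The file is written inside a `with` block so it is complete before anything reads it back.

```python
import json
from typing import Iterable, List, Mapping


def conversation_to_records(conversation: Iterable[Mapping[str, str]]) -> List[dict]:
    """Mirror the script's loop: one {"train": text} record per conversation turn."""
    records = []
    out = "<|input|><|response|>"
    for turn in conversation:
        out = out + turn["prompt"]
        records.append({"train": out})
        # As in the original loop, the next record starts from this turn's response.
        out = "<|input|>" + turn["response"] + "<|response|>"
    return records


def write_jsonl(conversations: Iterable[Iterable[Mapping[str, str]]], path: str) -> int:
    """Write every record as one JSON line; return the number of lines written."""
    count = 0
    with open(path, "w", encoding="utf-8") as f:  # flushed and closed on exit
        for conversation in conversations:
            for record in conversation_to_records(conversation):
                f.write(json.dumps(record) + "\n")
                count += 1
    return count


if __name__ == "__main__":
    demo = [{"prompt": "hi", "response": "hello"}, {"prompt": "bye", "response": "ciao"}]
    print(conversation_to_records(demo))
    # [{'train': '<|input|><|response|>hi'}, {'train': '<|input|>hello<|response|>bye'}]
```

With the script's data this would be called as `write_jsonl(jcdataset['chunks'], "./lmsys-toxic-gpt.json")`. Exactly as in the original loop, each record ends with a prompt, and the last `response` of every conversation is never written.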

erfanzar commented 9 months ago

Hello, the issue with GPT2 is now fixed; you can rerun your script.

erfanzar commented 9 months ago

Is the issue fixed?

erfanzar commented 9 months ago

This issue is being closed because no response has been given.