erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options with JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

Training on Kaggle's TPU is failing #117

Closed saidineshpola closed 4 months ago

saidineshpola commented 4 months ago

Describe the bug

ValueError: Memory kinds passed to jax.jit does not match memory kind on the respective arg. Got pjit memory kind: tpu_hbm, arg memory kind: None for arg shape: float16[51200]
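For context: this error says the jitted training step expects TPU-resident inputs (memory kind tpu_hbm) while the parameter arrays carry no memory kind, which usually means they were never explicitly placed on the device mesh. A minimal sketch of the general workaround, assuming the fix is an explicit jax.device_put before the jitted call (the mesh and sharding below are illustrative, not EasyDeL's internals):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative one-axis mesh over all available devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))
replicated = NamedSharding(mesh, PartitionSpec())

# device_put gives the array a concrete sharding (and memory kind)
# before it reaches a jitted function that constrains its inputs.
params = jax.device_put(jnp.zeros((51200,), jnp.float16), replicated)

@jax.jit
def step(p):
    return p * 2.0

step(params)  # no memory-kind mismatch once placement is explicit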

To Reproduce

!pip install datasets
!pip install git+https://github.com/erfanzar/EasyDeL.git
!pip install jax[tpu]==0.4.21 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

import os
# disable Weights and Biases
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    # BitsAndBytesConfig,
    HfArgumentParser,
    GenerationConfig
)
from tqdm import tqdm
import time
import pandas as pd
import numpy as np
from functools import partial
from transformers import set_seed
from ast import literal_eval
# from huggingface_hub import interpreter_login

# interpreter_login()
IS_TPU=True

import jax.numpy
from EasyDel import (
    TrainArguments,
    CausalLanguageModelTrainer,
    AutoEasyDelModelForCausalLM,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers
)
import flax
from jax import numpy as jnp

print('Import successful')

# disable Weights and Biases
os.environ['WANDB_DISABLED']="true"
huggingface_repo_id_or_path = "microsoft/phi-2"

model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path)

max_length = 2048
tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, 1)
}

train_arguments = TrainArguments(
    model_class=type(model),
    model_name="my_first_model_to_train_using_easydel",
    num_train_epochs=3,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDelOptimizers.ADAMW,  # "adamw", "lion", "adafactor" are supported
    scheduler=EasyDelSchedulers.LINEAR,
    # "linear","cosine", "none" ,"warm_up_cosine" and "warm_up_linear"  are supported
    weight_decay=0.01,
    total_batch_size=64,
    max_training_steps=None,  # None lets the trainer decide
    do_train=True,
    do_eval=False,  # optional but supported
    backend="tpu",  # the default backend is cpu, so specify tpu, gpu, or cpu explicitly
    max_length=max_length,  # note that you have to change this in the model config too
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, 1, 1, -1),  # how to shard the model across CPUs, GPUs, or TPUs;
    # with (1, 1, 1, -1), training runs sequence- and model-parallel automatically and shares data between devices
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=8,
    loss_re_mat="",
    dtype=jnp.bfloat16
)

# def ultra_chat_prompting_process(
#         data_chunk
# ):
#     user_part = [
#         chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "user"
#     ]
#     assistant_part = [
#         chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "assistant"
#     ]

#     prompt = ""

#     for uc, ac in zip(user_part, assistant_part):
#         prompt += f"<|user|>\n{uc}</s>\n<|assistant|>\n{ac}</s>\n"

#     return {"prompt": prompt}

# tokenization_process = lambda data_chunk: tokenizer(
#     data_chunk["prompt"],
#     add_special_tokens=False,
#     max_length=max_length,
#     padding="max_length"
# )

# dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
# dataset_train = dataset["train_gen"].map(ultra_chat_prompting_process, num_proc=12)
# dataset_train = dataset_train.map(
#     tokenization_process,
#     num_proc=12,
#     remove_columns=dataset_train.column_names
# )

# you can do the same for the evaluation dataset

trainer = CausalLanguageModelTrainer(
    train_arguments,
    train_dataset,
    checkpoint_path=None
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey ! , here's where your model saved {output.checkpoint_path}") 


erfanzar commented 4 months ago

Hello, and thanks for using EasyDeL! Your code seems to have some problems; I'll try to run it, but make sure you're using the right version of JAX:

pip install jax[tpu]==0.21.0 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

0.21.0 does not exist; we are currently on 0.4.25 / 26.
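A quick check to confirm which JAX build the kernel actually picked up after the install (restart the Kaggle kernel first if JAX was already imported):

import jax
print(jax.__version__)  # should match the version pinned in the pip install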

saidineshpola commented 4 months ago

My bad, I forgot to add 0.4.21. The default JAX version in Kaggle, 0.4.25, is giving import errors, so I used an older one.

erfanzar commented 4 months ago

Cool, but as I have shown in the examples, when using Kaggle you should upgrade TensorFlow too, because the newer JAX version needs newer TPU software; upgrading TensorFlow fixes those TPU errors.
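A quick sanity check after the upgrade (the device count below assumes a Kaggle TPU v3-8 runtime):

import jax
# With the newer TPU runtime in place, JAX should enumerate the TPU cores
# instead of silently falling back to CPU.
print(jax.default_backend())  # expect 'tpu'
print(jax.devices())          # expect 8 TpuDevice entries on a v3-8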

erfanzar commented 4 months ago

There are many tricks and settings you must consider when working with EasyDeL, such as training arguments like max_sequence_length, partition_spec, and so on. I have modified your code; here's the working version:

pip install datasets
pip install git+https://github.com/erfanzar/EasyDeL.git
pip install jax[tpu]==0.4.23 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install tensorflow -U

import os
# disable Weights and Biases
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    # BitsAndBytesConfig,
    HfArgumentParser,
    GenerationConfig
)
from tqdm import tqdm
import time
import pandas as pd
import numpy as np
from functools import partial
from transformers import set_seed
from ast import literal_eval
# from huggingface_hub import interpreter_login

# interpreter_login()
IS_TPU=True

import jax.numpy
from EasyDel import (
    TrainArguments,
    CausalLanguageModelTrainer,
    AutoEasyDelModelForCausalLM,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers
)
import flax
from jax import numpy as jnp

# disable Weights and Biases
os.environ['WANDB_DISABLED']="true"
huggingface_repo_id_or_path = "microsoft/phi-2"

model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path)

max_length = 2048
tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, 1)
}

train_arguments = TrainArguments(
    model_class=type(model),
    model_name="my_first_model_to_train_using_easydel",
    num_train_epochs=3,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    max_sequence_length=max_length,
    optimizer=EasyDelOptimizers.ADAMW,  # "adamw", "lion", "adafactor" are supported
    scheduler=EasyDelSchedulers.LINEAR,
    # "linear","cosine", "none" ,"warm_up_cosine" and "warm_up_linear"  are supported
    weight_decay=0.01,
    total_batch_size=8,
    max_training_steps=None,  # None lets the trainer decide
    do_train=True,
    do_eval=False,  # optional but supported
    backend="tpu",  # the default backend is cpu, so specify tpu, gpu, or cpu explicitly
    max_length=max_length,  # note that you have to change this in the model config too
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, 1, 1, -1),  # how to shard the model across CPUs, GPUs, or TPUs;
    # with (1, 1, 1, -1), training runs sequence- and model-parallel automatically and shares data between devices
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=1,
    loss_re_mat="",
    dtype=jnp.bfloat16,
    use_wandb=False  # disables Weights & Biases (wandb) logging
)

def ultra_chat_prompting_process(
        data_chunk
):
    user_part = [
        chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "user"
    ]
    assistant_part = [
        chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "assistant"
    ]

    prompt = ""

    for uc, ac in zip(user_part, assistant_part):
        prompt += f"<|user|>\n{uc}</s>\n<|assistant|>\n{ac}</s>\n"

    return {"prompt": prompt}

tokenization_process = lambda data_chunk: tokenizer(
    data_chunk["prompt"],
    add_special_tokens=False,
    max_length=max_length,
    padding="max_length"
)

dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
dataset_train = dataset["train_gen"].map(ultra_chat_prompting_process, num_proc=12)
dataset_train = dataset_train.map(
    tokenization_process,
    num_proc=12,
    remove_columns=dataset_train.column_names
)

# you can do the same for the evaluation dataset
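# For example, a sketch using ultrachat_200k's matching eval split
# ("test_gen"); note that do_eval=True would also be needed above, and how
# the eval set is handed to the trainer depends on the EasyDeL version:
dataset_eval = dataset["test_gen"].map(ultra_chat_prompting_process, num_proc=12)
dataset_eval = dataset_eval.map(
    tokenization_process,
    num_proc=12,
    remove_columns=dataset_eval.column_names
)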

trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset_train,
    checkpoint_path=None
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey ! , here's where your model saved {output.checkpoint_path}") 
saidineshpola commented 4 months ago

Thanks for the quick resolution over the weekend.