Closed saidineshpola closed 4 months ago
Hello, and thanks for using EasyDeL! Your code seems to have some problems. I'll try to run it, but make sure you're using the right version of JAX:
pip install jax[tpu]==0.21.0 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
My bad, I forgot to write it as 0.4.21. The default JAX version in Kaggle (0.4.25) gives import errors, so I used an older one.
Cool, but as I have shown in the examples, when using Kaggle you should upgrade TensorFlow too, because the newer JAX version needs newer TPU software, and upgrading TensorFlow fixes those TPU errors.
There are many tricks and details you must consider when working with EasyDeL, such as training arguments like max_sequence_length, partition_spec, and so on. I have modified your code, and here's the working version:
# Hugging Face `datasets`, used by the script below to load the training data.
pip install datasets
# EasyDeL, installed directly from its GitHub repository.
pip install git+https://github.com/erfanzar/EasyDeL.git
# JAX with the TPU extra, pinned to 0.4.23; -f adds the libtpu release page as a package source.
pip install jax[tpu]==0.4.23 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Upgrade TensorFlow too; per the discussion in this thread, this fixes the TPU errors on Kaggle.
pip install tensorflow -U
import os
# disable Weights and Biases
from datasets import load_dataset
from transformers import (
AutoTokenizer,
#BitsAndBytesConfig,
HfArgumentParser,
AutoTokenizer,
GenerationConfig
)
from tqdm import tqdm
import time
import pandas as pd
import numpy as np
from functools import partial
from transformers import set_seed
from ast import literal_eval
# from huggingface_hub import interpreter_login
# interpreter_login()
IS_TPU=True
import jax.numpy
from EasyDel import (
TrainArguments,
CausalLanguageModelTrainer,
AutoEasyDelModelForCausalLM,
EasyDelOptimizers,
EasyDelSchedulers,
EasyDelGradientCheckPointers
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer
# Switch Weights & Biases off through its environment variable.
os.environ["WANDB_DISABLED"] = "true"

# Model identifier shared by the model and tokenizer `from_pretrained` calls.
huggingface_repo_id_or_path = "microsoft/phi-2"
# Sequence length reused for tokenization and for the training arguments.
max_length = 2048

# `from_pretrained` returns the model object together with its parameters.
model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path)

tokenizer = AutoTokenizer.from_pretrained(huggingface_repo_id_or_path, trust_remote_code=True)
# Pad with the end-of-sequence token.
tokenizer.pad_token = tokenizer.eos_token

# Handed to TrainArguments as `configs_to_initialize_model_class`.
configs_to_initialize_model_class = dict(
    config=model.config,
    dtype=jnp.bfloat16,
    param_dtype=jnp.bfloat16,
    input_shape=(1, 1),
)
# Training configuration consumed by CausalLanguageModelTrainer.
train_arguments = TrainArguments(
    # --- model ---------------------------------------------------------
    model_class=type(model),
    model_name="my_first_model_to_train_using_easydel",
    num_train_epochs=3,
    configs_to_initialize_model_class=configs_to_initialize_model_class,

    # --- optimisation --------------------------------------------------
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    max_sequence_length=max_length,
    # "adamw", "lion" and "adafactor" are supported.
    optimizer=EasyDelOptimizers.ADAMW,
    # "linear", "cosine", "none", "warm_up_cosine" and "warm_up_linear"
    # are supported.
    scheduler=EasyDelSchedulers.LINEAR,
    weight_decay=0.01,
    total_batch_size=8,
    # None lets the trainer decide the number of steps.
    max_training_steps=None,

    # --- run mode ------------------------------------------------------
    do_train=True,
    # Evaluation is optional but supported.
    do_eval=False,
    # The default backend is cpu, so state explicitly whether you want
    # tpu, cpu or gpu.
    backend="tpu",
    # Note: this has to be changed in the model config too.
    max_length=max_length,

    # --- memory / sharding ---------------------------------------------
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    # How the model is sharded across GPUs, CPUs or TPUs. With
    # (1, 1, 1, -1) training is sequence- and model-parallel
    # automatically and data is shared between devices.
    sharding_array=(1, 1, 1, -1),
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=1,
    loss_re_mat="",
    dtype=jnp.bfloat16,
    # Disables Weights & Biases usage.
    use_wandb=False,
)
def ultra_chat_prompting_process(data_chunk):
    """Flatten one chat record into a single training prompt.

    Args:
        data_chunk: Mapping with a ``"messages"`` entry, an iterable of
            dicts that each provide ``"role"`` and ``"content"`` keys.

    Returns:
        A dict ``{"prompt": text}``. ``text`` is the concatenation of one
        segment per user/assistant pair, where the user content follows a
        ``<|user|>`` tag, the assistant content follows an
        ``<|assistant|>`` tag, and each content is terminated by ``</s>``.
        Messages with any other role are ignored.

    Raises:
        KeyError: If ``"messages"`` or a message's ``"role"``/``"content"``
            key is missing.
    """
    messages = data_chunk["messages"]
    user_part = [m["content"] for m in messages if m["role"] == "user"]
    assistant_part = [m["content"] for m in messages if m["role"] == "assistant"]

    # zip pairs turns by position and stops at the shorter list, so a
    # trailing turn without a counterpart is left out of the prompt.
    prompt = "".join(
        f"<|user|>\n{uc}</s>\n<|assistant|>\n{ac}</s>\n"
        for uc, ac in zip(user_part, assistant_part)
    )
    return {"prompt": prompt}
def tokenization_process(data_chunk):
    """Tokenize the ``"prompt"`` field of ``data_chunk``.

    Calls the module-level ``tokenizer`` without adding special tokens and
    pads the result to ``max_length``.
    """
    # NOTE(review): no `truncation` argument is passed here; confirm that
    # prompts longer than `max_length` are handled downstream.
    return tokenizer(
        data_chunk["prompt"],
        padding="max_length",
        max_length=max_length,
        add_special_tokens=False,
    )
# Load the UltraChat 200k dataset.
dataset = load_dataset("HuggingFaceH4/ultrachat_200k")

# Step 1: build a "prompt" field for every record of the "train_gen" split,
# using 12 worker processes.
dataset_train = dataset["train_gen"].map(
    ultra_chat_prompting_process,
    num_proc=12,
)

# Step 2: tokenize the prompts and drop every column that existed before
# tokenization, again with 12 worker processes.
dataset_train = dataset_train.map(
    tokenization_process,
    remove_columns=dataset_train.column_names,
    num_proc=12,
)
# The evaluation dataset can be prepared in the same way.
# Build the trainer from the arguments and the tokenized training set;
# no checkpoint path is supplied.
trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset_train,
    checkpoint_path=None,
)

# Wrap the loaded parameters under a "params" key in a FrozenDict and train.
initial_state = flax.core.FrozenDict({"params": params})
output = trainer.train(initial_state)

print(f"Hey ! , here's where your model saved {output.checkpoint_path}")
Thanks for the quick resolution over the weekend.
Describe the bug ValueError: Memory kinds passed to jax.jit does not match memory kind on the respective arg. Got pjit memory kind: tpu_hbm, arg memory kind: None for arg shape: float16[51200]
To Reproduce
Steps to reproduce the behavior