Describe the bug
While trying to fine-tune a GPT-2 model (and also another non-Llama model), I get the following exception. Am I missing something?
To Reproduce
Take a GPT-2 model and train it on some data.
The following exception is raised:
Downloading data files: 100%|██████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 9868.95it/s]
Extracting data files: 100%|███████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1449.81it/s]
Generating train split: 2942 examples [00:00, 241149.94 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████████████| 2942/2942 [00:00<00:00, 20119.71 examples/s]
Warning : In case of using finetune = True and Passing checkpoint_path = None you should pass parametersin train function
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice: 3
wandb: You chose "Don't visualize my results"
wandb: Tracking run with wandb version 0.16.2
wandb: W&B syncing is set to `offline` in this directory.
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
Time Took to Complete Task configure dataloaders (microseconds) : 0.20360946655273438
Time Took to Complete Task configure Model ,Optimizer ,Scheduler and Config (microseconds) : 654.585599899292
Traceback (most recent call last):
  File "/home/***/research/EasyDeL/train_ravengpt.py", line 74, in <module>
    trainer = CausalLanguageModelTrainer(
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 244, in __init__
    self.init_functions()
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 310, in init_functions
    funcs = self.configure_functions()
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 477, in configure_functions
    create_sharded_state_from_params_fn = pjit(
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 762, in pjit
    static_argnames) = pre_infer_params(
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 309, in pre_infer_params
    in_shardings, _, _ = prepare_axis_resources(in_shardings, 'in_shardings')
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 1093, in prepare_axis_resources
    new_entries.append(ParsedPartitionSpec.from_user_input(
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 985, in from_user_input
    raise TypeError(f"{arg_name} are expected to be "
TypeError: in_shardings leaf specifications are expected to be PartitionSpec instances or None, but got *
wandb: You can sync this run to the cloud by running:
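The error message says that some leaf of the `in_shardings` tree handed to `pjit` is neither a `PartitionSpec` nor `None`. For anyone trying to narrow this down, here is a minimal sketch that walks a pytree of partition specs and lists the offending leaves. It assumes the specs are available as an ordinary pytree; the helper name `find_invalid_sharding_leaves` and the example tree are hypothetical and not part of EasyDeL.

```python
import jax
from jax.sharding import PartitionSpec, Sharding


def find_invalid_sharding_leaves(spec_tree):
    """Return (path, type name) for leaves that pjit would likely reject."""
    # None is normally an empty pytree node, so mark it as a leaf explicitly
    # to keep it visible while walking the tree.
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(
        spec_tree, is_leaf=lambda x: x is None
    )
    bad = []
    for path, leaf in leaves_with_paths:
        # PartitionSpec and None are what the error message asks for;
        # Sharding objects are accepted by JAX as well.
        if leaf is None or isinstance(leaf, (PartitionSpec, Sharding)):
            continue
        bad.append((jax.tree_util.keystr(path), type(leaf).__name__))
    return bad


# Hypothetical spec tree: the last entry is a plain string, not a PartitionSpec.
example_specs = {
    "transformer": {
        "wte": PartitionSpec("fsdp", None),
        "ln_f": None,
        "h_0": "fsdp",
    }
}

for path, type_name in find_invalid_sharding_leaves(example_specs):
    print(f"invalid sharding leaf at {path}: {type_name}")
```

Running something like this on the partition specs built for the model (wherever the trainer exposes them) should show which parameter paths did not get a valid spec.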
Test Code
import json
import jax.numpy
from EasyDel import (
    TrainArguments,
    CausalLanguageModelTrainer,
    AutoEasyDelModelForCausalLM,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer
huggingface_repo_id_or_path = "gpt2"
model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path)
max_length = 256
tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_init_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, 1)
}
train_arguments = TrainArguments(
    model_class=type(model),
    model_name="raven_gpt_using_easydel",
    num_train_epochs=3,
    configs_to_init_model_class=configs_to_init_model_class,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDelOptimizers.ADAMW,  # "adamw", "lion", "adafactor" are supported
    scheduler=EasyDelSchedulers.LINEAR,
    # "linear", "cosine", "none", "warm_up_cosine" and "warm_up_linear" are supported
    weight_decay=0.01,
    total_batch_size=2,
    max_steps=100,  # None to let the trainer decide
    do_train=True,
    do_eval=False,  # optional but supported
    backend="tpu",  # the default backend is cpu, so specify tpu, cpu or gpu explicitly
    max_length=max_length,  # note that this has to be changed in the model config too
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),  # how to shard the model across GPUs, CPUs or TPUs
    # with (1, -1, 1, 1), training is fully automatic FSDP and data is shared between devices
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=8,
    loss_re_mat="",
    dtype=jnp.bfloat16
)
with open("tmp_json.jsonl", "w") as tmp_json, open("data.jsonl", "r") as source:
    for line in source:
        record = json.loads(line.strip())
        turn_resp = "<|input|>" + record["prompt"] + "<|response|>" + record["response"]
        # Write one JSON object per line so the output is valid JSON Lines
        tmp_json.write(json.dumps({"turn_resp": turn_resp}) + "\n")
data = load_dataset("json", data_files="tmp_json.jsonl", split="train")
data = data.map(lambda samples: tokenizer(samples["turn_resp"]), batched=True)
data = data.train_test_split(test_size=0.001)
# You can do the same for the evaluation dataset
trainer = CausalLanguageModelTrainer(
    train_arguments,
    data,
    checkpoint_path=None
)
output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey ! , here's where your model saved {output.checkpoint_path}")
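For context on `sharding_array=(1, -1, 1, 1)`: the `-1` entry is presumably inferred from the number of available devices, the same way `-1` works in a NumPy reshape. Below is a small plain-Python sketch of that resolution. The function name `resolve_mesh_shape` is hypothetical, and whether EasyDeL resolves the array exactly like this is an assumption on my part.

```python
from math import prod


def resolve_mesh_shape(sharding_array, device_count):
    """Replace a single -1 in sharding_array so the product equals device_count."""
    dims = tuple(sharding_array)
    if dims.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    if any(d == 0 or d < -1 for d in dims):
        raise ValueError("dimensions must be positive or -1")

    # Product of the explicitly given dimensions.
    known = prod(d for d in dims if d != -1)

    if -1 not in dims:
        if known != device_count:
            raise ValueError(f"{dims} does not match {device_count} devices")
        return dims

    if device_count % known != 0:
        raise ValueError(f"{device_count} devices cannot be split as {dims}")

    inferred = device_count // known
    return tuple(inferred if d == -1 else d for d in dims)


# Assumed example: a host with 8 devices.
print(resolve_mesh_shape((1, -1, 1, 1), 8))  # (1, 8, 1, 1)
```

Under that assumption, all devices end up on the second mesh axis, which matches the "fully FSDP" comment in the training arguments.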