erfanzar / EasyDeL

Accelerate and optimize performance with streamlined training and serving options in JAX.
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

While training Gpt2 model - Exception - TypeError: in_shardings leaf specifications are expected to be PartitionSpec instances or None, but got * #89

Closed jchauhan closed 8 months ago

jchauhan commented 8 months ago

Describe the bug: While trying to fine-tune a GPT2 model (and another non-Llama model as well), I get the following exception. Am I missing something?

To Reproduce: Take a GPT2 model and train it on some data; the following exception will be raised:

Downloading data files: 100%|██████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 9868.95it/s]
Extracting data files: 100%|███████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 1449.81it/s]
Generating train split: 2942 examples [00:00, 241149.94 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████████████| 2942/2942 [00:00<00:00, 20119.71 examples/s]
Warning : In case of using finetune = True and Passing checkpoint_path = None you should pass parametersin train function
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice: 3
wandb: You chose "Don't visualize my results"
wandb: Tracking run with wandb version 0.16.2
wandb: W&B syncing is set to `offline` in this directory.  
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
Time Took to Complete Task configure dataloaders (microseconds) : 0.20360946655273438
Time Took to Complete Task configure Model ,Optimizer ,Scheduler and Config (microseconds) : 654.585599899292
Traceback (most recent call last):
  File "/home/***/research/EasyDeL/train_ravengpt.py", line 74, in <module>
    trainer = CausalLanguageModelTrainer(
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 244, in __init__
    self.init_functions()
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 310, in init_functions
    funcs = self.configure_functions()
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 477, in configure_functions
    create_sharded_state_from_params_fn = pjit(
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 762, in pjit
    static_argnames) = pre_infer_params(
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 309, in pre_infer_params
    in_shardings, _, _ = prepare_axis_resources(in_shardings, 'in_shardings')
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 1093, in prepare_axis_resources
    new_entries.append(ParsedPartitionSpec.from_user_input(
  File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/jax/_src/sharding_impls.py", line 985, in from_user_input
    raise TypeError(f"{arg_name} are expected to be "
TypeError: in_shardings leaf specifications are expected to be PartitionSpec instances or None, but got *
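
For context, the check that raises here lives in JAX itself, not EasyDeL: recent jax releases require every in_shardings leaf handed to pjit to be a PartitionSpec instance or None, and they reject the string wildcards (such as "*") that older partition-rule helpers produced. A minimal sketch of the two cases in plain JAX (nothing here is EasyDeL-specific):

from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec

def identity(x):
    return x

# The failing pattern from the traceback: a string wildcard leaking
# into in_shardings. Recent JAX rejects this when pjit() is called:
#   pjit(identity, in_shardings=("*",), out_shardings=PartitionSpec())
#   -> TypeError: in_shardings leaf specifications are expected to be
#      PartitionSpec instances or None, but got *

# What JAX expects instead: a PartitionSpec (or None) per leaf.
sharded_identity = pjit(
    identity,
    in_shardings=(PartitionSpec(),),  # empty spec = fully replicated
    out_shardings=PartitionSpec(),
)
# (Actually running sharded_identity still requires an active
# jax.sharding.Mesh context.)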

Test Code

import json
from EasyDel import (
    TrainArguments,
    CausalLanguageModelTrainer,
    AutoEasyDelModelForCausalLM,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer

huggingface_repo_id_or_path = "gpt2"

model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path)

max_length = 256
tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_init_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, 1)
}

train_arguments = TrainArguments(
    model_class=type(model),
    model_name="raven_gpt_using_easydel",
    num_train_epochs=3,
    configs_to_init_model_class=configs_to_init_model_class,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDelOptimizers.ADAMW,  # "adamw", "lion", "adafactor" are supported
    scheduler=EasyDelSchedulers.LINEAR,
    # "linear","cosine", "none" ,"warm_up_cosine" and "warm_up_linear"  are supported
    weight_decay=0.01,
    total_batch_size=2,
    max_steps=100,  # set to None to let the trainer decide
    do_train=True,
    do_eval=False,  # it's optional but supported 
    backend="tpu",  # default backed is set to cpu, so you must define you want to use tpu cpu or gpu
    max_length=max_length,  # Note that you have to change this in the model config too
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),  # how the model is sharded across CPUs, GPUs, or TPUs;
    # with (1, -1, 1, 1), training runs as fully automatic FSDP, sharing data between devices
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=8,
    loss_re_mat="",
    dtype=jnp.bfloat16
)

with open("tmp_json.jsonl", "w") as tmp_json:
    for line in open("data.jsonl", "r"):
        record = json.loads(line.strip())
        turn_resp = "<|input|>" + record["prompt"] + "<|response|>" + record["response"]
        tmp_json.write(json.dumps({"turn_resp": turn_resp}))

data = load_dataset("json", data_files="tmp_json.jsonl", split="train")
data = data.map(lambda samples: tokenizer(samples["turn_resp"]), batched=True)
data = data.train_test_split(test_size=0.001)

# you can do the same for evaluation process dataset

trainer = CausalLanguageModelTrainer(
    train_arguments,
    data,
    checkpoint_path=None
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey ! , here's where your model saved {output.checkpoint_path}")
erfanzar commented 8 months ago

Install EasyDeL from the latest source via:

pip install git+https://github.com/erfanzar/EasyDeL.git -U --no-cache
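
After upgrading, it is worth confirming that Python actually imports the fresh build; a quick check (the distribution name "EasyDeL" is an assumption about the package metadata, adjust if pip reports a different name):

import importlib.metadata
import EasyDel

# The distribution name is assumed to match the repo name.
print(importlib.metadata.version("EasyDeL"))
print(EasyDel.__file__)  # should point at the newly installed package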