erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options in JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

'LoraWeight' object has no attribute 'tolist' #145

Closed · defdet closed this 2 months ago

defdet commented 2 months ago

**Describe the bug**

When trying to fine-tune Qwen 32B with LoRA on a v4-32 TPU pod, the error below occurs. The full traceback is in the next message for ease of reading.

**To Reproduce**

Full code:

```python
from huggingface_hub import hf_hub_download
import jax
from jax import numpy as jnp
from flax.core import FrozenDict
from transformers import AutoTokenizer
from datasets import Dataset
from tqdm import tqdm

# Download the train/test splits and load them as Datasets
hf_hub_download(repo_id="Defetya/saiga-ds-merged", filename="train.json", local_dir='.', repo_type='dataset')
hf_hub_download(repo_id="Defetya/saiga-ds-merged", filename="test.json", local_dir='.', repo_type='dataset')

train_ds = Dataset.from_json('train.json')
test_ds = Dataset.from_json('test.json')

from EasyDel import (
    TrainArguments,
    SFTTrainer,
    AutoEasyDelModelForCausalLM,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers,
    conversations_formatting_function,
    EasyDeLXRapTureConfig,
)
import EasyDel
print(EasyDel.__version__)

device_num = len(jax.devices())
huggingface_repo_id_or_path = "/mnt/disks/huggingface_checkpoints"

# from_pretrained returns the Flax module and its parameters separately
model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path, input_shape=(device_num, 1))
model_parameters = FrozenDict({"params": params})

max_length = 1024
tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Character-length statistics over the training conversations
lengths = []
for sample in tqdm(train_ds):
    msg = sample['messages']
    length = 0
    for turn in msg:
        length += len(turn['content'])
    lengths.append(length)
print(max(lengths), sum(lengths) / len(lengths))
# Use ring attention with the given block sizes
model.config.add_basic_configurations(
    attn_mechanism="ring",
    block_b=1,
    block_q=128,
    block_k=128,
    block_k_major=128,
)

configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, max_length)
}
rapture = EasyDeLXRapTureConfig(
    parameters=model_parameters,
    lora_dim=16,
    lora_fine_tune_parameters=["q_proj", "o_proj"],  # LoRA target layers; pass None
    # for layer tuning only or transfer learning
    verbose=True,
    dtype=jnp.bfloat16,
)

train_arguments = TrainArguments(
    model_class=type(model),
    model_name="fine_tuned_qwen_32",
    num_train_epochs=3,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    learning_rate=8e-5,
    learning_rate_end=1e-5,
    optimizer=EasyDelOptimizers.ADAMW,
    scheduler=EasyDelSchedulers.LINEAR,
    weight_decay=0.01,
    total_batch_size=1,
    max_training_steps=None,
    do_train=True,
    do_eval=False,
    backend="tpu",
    max_sequence_length=max_length,
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, 1, 4, 4),
    init_input_shape=(1, max_length),
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=1,
    loss_re_mat="",
    rapture_config=rapture,
    merge_lora_rapture_parameters=True,
    dtype=jnp.bfloat16,
)
trainer = SFTTrainer(
    arguments=train_arguments,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    tokenizer=tokenizer,
    dataset_text_field=None,
    formatting_func=lambda x: [conversations_formatting_function(tokenizer, messages_field="messages")(x)],
    packing=False,
)

output = trainer.train()
```
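For context, assuming `conversations_formatting_function` mirrors the TRL helper of the same name, the `formatting_func` lambda above just renders each sample's `messages` list through the tokenizer's chat template, roughly like this sketch:

```python
# Rough sketch (assumption: the helper mirrors TRL's function of the same
# name; the real implementation may differ). It maps a sample's "messages"
# list to a single chat-templated training string.
def format_sample(sample):
    return tokenizer.apply_chat_template(sample["messages"], tokenize=False)
```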
defdet commented 2 months ago

Full traceback:

```
  File "/home/Username/train.py", line 118, in <module>
    output = trainer.train()
  File "/home/Username/.local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer/causal_language_model_trainer.py", line 520, in train
    train_metrics.update({
  File "/home/Username/.local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer/causal_language_model_trainer.py", line 521, in <dictcomp>
    f"grad_norm/{layer_name}": grad_norm.tolist()
AttributeError: 'LoraWeight' object has no attribute 'tolist'
```
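The failing leaves are LoRA-wrapped parameters rather than plain arrays, so `.tolist()` isn't defined on them. A minimal sketch of the failure mode, using a hypothetical stand-in for the wrapper class:

```python
# Minimal sketch of the failure mode. This LoraWeight is a hypothetical
# stand-in, assuming the real wrapper holds a frozen base weight plus
# low-rank factors and does not forward array methods like .tolist().
from dataclasses import dataclass
import jax.numpy as jnp

@dataclass
class LoraWeight:
    w: jnp.ndarray  # frozen base weight
    a: jnp.ndarray  # low-rank factor, shape (rank, in_features)
    b: jnp.ndarray  # low-rank factor, shape (out_features, rank)

grad = LoraWeight(w=jnp.zeros((8, 8)), a=jnp.zeros((2, 8)), b=jnp.zeros((8, 2)))
grad.tolist()  # AttributeError: 'LoraWeight' object has no attribute 'tolist'
```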
defdet commented 2 months ago

Latest versions of FJFormer (0.0.50) and EasyDeL (0.0.61).

erfanzar commented 2 months ago

There was an issue from the last pull request; I have fixed it. Can you try again?
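For anyone hitting this before upgrading, one plausible shape for such a guard (a sketch under the assumed LoraWeight layout above, not the actual commit) is to materialize LoRA-wrapped leaves into dense arrays before logging:

```python
# Hypothetical workaround sketch, not the actual fix: materialize a
# LoRA-wrapped leaf (frozen w plus low-rank update b @ a, layout assumed)
# into a dense array so that .tolist() is available.
def loggable(leaf):
    if hasattr(leaf, "tolist"):  # plain jax array
        return leaf.tolist()
    return (leaf.w + leaf.b @ leaf.a).tolist()  # assumed LoraWeight layout
```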

defdet commented 2 months ago

Cheers, seems to work.