erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options with JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

Out-of-memory issue in new EasyDeL version #155

Closed: nyl199310 closed this issue 1 month ago

nyl199310 commented 1 month ago

Hi, I can run the code below with a previous EasyDeL version without any problem:

!pip install git+https://github.com/erfanzar/EasyDeL.git@d06931e79cc3ef63920007d9e4f95fd0289df3cf # This version works well.
!pip install fjformer==0.0.51

But when I use the latest EasyDeL, it fails with:

XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 15.86G of 15.48G hbm. Exceeded hbm capacity by 382.92M.
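A rough sense of scale helps when reading this error. The sketch below only estimates bf16 weight memory: it assumes Llama-3-8B has about 8.03e9 parameters and that sharded weights are split evenly across devices, and it ignores activations, gradients, optimizer state and XLA scratch space, so the real footprint is larger. It is an illustration of magnitudes, not a diagnosis of the regression.

```python
def weight_gib(num_params: float, bytes_per_param: int, num_devices: int = 1) -> float:
    """Approximate per-device weight memory in GiB, assuming an even split across devices."""
    total_bytes = num_params * bytes_per_param
    return total_bytes / num_devices / 2**30


# Assumption: Llama-3-8B has roughly 8.03e9 parameters.
LLAMA3_8B_PARAMS = 8.03e9

# bf16 = 2 bytes per parameter.
replicated = weight_gib(LLAMA3_8B_PARAMS, bytes_per_param=2)                # one full copy per device
sharded_8 = weight_gib(LLAMA3_8B_PARAMS, bytes_per_param=2, num_devices=8)  # split over 8 devices

print(f"bf16 weights, one full copy: {replicated:.2f} GiB")
print(f"bf16 weights, split over 8 devices: {sharded_8:.2f} GiB")

# The overflow in the error message is simply "used" minus "capacity".
overflow_g = 15.86 - 15.48
print(f"overflow: about {overflow_g:.2f} G (reported as 382.92M)")
```

A single full bf16 copy of the weights (about 14.96 GiB) is already close to the 15.48G capacity the error reports, which suggests why the way parameters are sharded across devices matters so much for a model of this size.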

from easydel import (
    AutoEasyDeLModelForCausalLM,
    AutoEasyDeLConfig,
    EasyDeLState,
    TrainArguments,
    EasyDeLOptimizers,
    EasyDeLSchedulers,
    SFTTrainer,
    ORPOTrainer,
    EasyDeLGradientCheckPointers,
    easystate_to_huggingface_model,
    get_modules_by_type
)
from datasets import load_dataset
from transformers import AutoTokenizer, LlamaForCausalLM, AutoConfig
from jax import numpy as jnp, lax
import jax
import flax
from huggingface_hub import HfApi

huggingface_model_repo_id = "NousResearch/Hermes-2-Pro-Llama-3-8B"
max_length = 8192

model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    huggingface_model_repo_id,  # pass the variable, not the string 'huggingface_model_repo_id'
    device=jax.devices('cpu')[0],
    input_shape=(1, max_length),
    device_map="auto",
    sharding_axis_dims=(1, 1, 1, -1),
    config_kwargs=dict(
        use_scan_mlp=False,
        attn_mechanism='sharded_vanilla',
    ),
)

config = AutoEasyDeLConfig.from_pretrained(
    huggingface_model_repo_id
)

tokenizer = AutoTokenizer.from_pretrained(
    huggingface_model_repo_id,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token

configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, max_length)
}

train_arguments = TrainArguments(
    model_class=get_modules_by_type(model.config.model_type)[1],
    model_name="llama3",
    num_train_epochs=1,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    learning_rate=2e-5,
#     step_start_point=step_start_point,
    learning_rate_end=2e-7,
    optimizer=EasyDeLOptimizers.ADAMW,
    scheduler=EasyDeLSchedulers.LINEAR,
    weight_decay=0.01,
    #dataloader_num_workers=96,
    total_batch_size=1,
    max_training_steps=None,
    do_train=True,
    do_eval=False,
    backend="tpu",
    max_sequence_length=max_length,
    gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, 1, 1, -1),
    init_input_shape=(1,max_length),
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=3,
    training_time="8H",
    track_memory=True,
    neftune_noise_alpha=5.0,
    force_batch_and_gradient_accumulation_steps_calculation=True,
    loss_re_mat="",
    dtype=jnp.bfloat16
)

train_dataset = load_dataset('csv',data_files="/kaggle/input/insert-p1/insert_p1.csv")['train']
desired_indices = range(0, 200)
train_dataset = train_dataset.select(desired_indices)

trainer = SFTTrainer(
    arguments=train_arguments,
    train_dataset=train_dataset,
    eval_dataset=None,
    tokenizer=tokenizer,
    dataset_text_field="text",
    dataset_num_proc=96,
    packing=False,
)

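The sharding_axis_dims=(1, 1, 1, -1) and sharding_array=(1, 1, 1, -1) settings above appear to use -1 as a placeholder for "all remaining devices", much like a NumPy reshape. The sketch below shows that resolution in plain Python under that assumption; resolve_mesh_dims is a hypothetical helper, not EasyDeL's or fjformer's implementation, and it does not model which named mesh axis each position maps to.

```python
from math import prod
from typing import Tuple


def resolve_mesh_dims(axis_dims: Tuple[int, ...], num_devices: int) -> Tuple[int, ...]:
    """Hypothetical helper: replace a single -1 in axis_dims with the remaining device count."""
    if axis_dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    known = prod(d for d in axis_dims if d != -1)
    if num_devices % known != 0:
        raise ValueError(f"{num_devices} devices cannot be split into {axis_dims}")
    fill = num_devices // known
    resolved = tuple(fill if d == -1 else d for d in axis_dims)
    # Without a -1, the explicit sizes must already multiply to the device count.
    if prod(resolved) != num_devices:
        raise ValueError(f"{axis_dims} does not use all {num_devices} devices")
    return resolved


# Assumption: the TPU runtime exposes 8 devices (e.g. a v3-8).
print(resolve_mesh_dims((1, 1, 1, -1), 8))  # (1, 1, 1, 8)
```

In a real run the device count would come from jax.device_count() rather than a hard-coded 8.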
erfanzar commented 1 month ago

Hello, and thanks for using EasyDeL. This is likely due to recent changes in the attention mechanism that fixed miscomputation problems. I'll test the same code, try to find the issue, and fix it.

nyl199310 commented 1 month ago

Thank you so much, @erfanzar. There is another issue: when using ORPOTrainer, tokenization is very slow, about 1 to 2 examples per second. Unlike SFTTrainer, it has no parameter such as dataset_num_proc=96; on the same hardware, SFTTrainer tokenizes about 3000 examples per second.

erfanzar commented 1 month ago

ORPOTrainer supports dataset_map_arguments, a dict that is passed to Dataset.map, but I've added dataset_num_proc to it for you anyway.
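Since dataset_map_arguments is unpacked straight into Dataset.map, its keys have to be keyword names that Dataset.map accepts (num_proc is the one that controls multiprocessing in Hugging Face datasets). A small guard that checks a kwargs dict against a callable's signature before forwarding it could look like the sketch below; unknown_kwargs and map_stub are hypothetical names, not part of EasyDeL or datasets, and map_stub only imitates the num_proc keyword.

```python
import inspect
from typing import Any, Callable, Dict, List


def unknown_kwargs(target: Callable[..., Any], kwargs: Dict[str, Any]) -> List[str]:
    """Return the keys in kwargs that target does not accept as keyword arguments."""
    params = inspect.signature(target).parameters.values()
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return []  # target takes **kwargs, so every key is accepted
    accepted = {
        p.name
        for p in params
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    return sorted(k for k in kwargs if k not in accepted)


def map_stub(function, batched=False, num_proc=None):
    """Hypothetical stand-in with a Dataset.map-like num_proc keyword."""
    return None


print(unknown_kwargs(map_stub, {"num_proc": 96}))          # []
print(unknown_kwargs(map_stub, {"dataset_num_proc": 96}))  # ['dataset_num_proc']
```

With the real trainer, the equivalent workaround would be passing dataset_map_arguments={"num_proc": 96} to ORPOTrainer. With datasets installed, the guard could in principle be pointed at datasets.Dataset.map itself, assuming its decorated signature is introspectable in the installed release.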

nyl199310 commented 1 month ago

@erfanzar Thank you!

erfanzar commented 1 month ago

@nyl199310 You can use legacy_sharded_vanilla to get the old attention behavior, but that implementation has many miscomputations across different devices.
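Following that suggestion, switching back would mean changing only attn_mechanism in the config_kwargs the original report passes to AutoEasyDeLModelForCausalLM.from_pretrained. This is a sketch based on the names used in this thread, not checked against a specific EasyDeL release; the accepted attention names may differ between versions.

```python
# Same config_kwargs as in the original report, with only attn_mechanism changed.
# These would be passed as config_kwargs=config_kwargs to
# AutoEasyDeLModelForCausalLM.from_pretrained(...), as in the code earlier in this thread.
config_kwargs = dict(
    use_scan_mlp=False,
    # Old attention path: per the maintainer, it may miscompute across devices.
    attn_mechanism="legacy_sharded_vanilla",
)
print(config_kwargs["attn_mechanism"])
```

Given the maintainer's warning, outputs produced with the legacy path are worth sanity-checking against the new sharded_vanilla path on a small input before trusting a long training run.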

nyl199310 commented 1 month ago

Hi, there is a bug since the dataset_num_proc parameter was added:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[2], line 7
      3 train_dataset = train_dataset.select(desired_indices)
      4 train_dataset = train_dataset.rename_column('question', 'prompt')
----> 7 trainer = ORPOTrainer(
      8     arguments=train_arguments,
      9     max_length = 8192,
     10     max_prompt_length = 8192,
     11     max_completion_length = 2048,
     12     beta = 0.1,
     13     train_dataset=train_dataset,
     14     eval_dataset=None,
     15     tokenizer=tokenizer,
     16     low_mem_usage=True,
     17 )
     19 output = trainer.train(flax.core.FrozenDict({"params": params}))

File /usr/local/lib/python3.10/site-packages/easydel/trainer/orpo/orpo_trainer.py:168, in ORPOTrainer.__init__(self, arguments, max_length, max_prompt_length, max_completion_length, beta, disable_dropout, label_pad_token_id, is_encoder_decoder, padding_value, data_collator, train_dataset, eval_dataset, tokenizer, dataset_num_proc, _do_init_fns, dataset_map_arguments, low_mem_usage)
    166 if dataset_map_arguments is None:
    167     dataset_map_arguments = {}
--> 168 train_dataset = train_dataset.map(
    169     self.tokenize_row,
    170     dataset_num_proc=dataset_num_proc,
    171     **dataset_map_arguments
    172 )
    173 if eval_dataset is not None:
    174     eval_dataset = eval_dataset.map(
    175         self.tokenize_row,
    176         dataset_num_proc=dataset_num_proc,
    177         **dataset_map_arguments
    178     )

File /usr/local/lib/python3.10/site-packages/datasets/arrow_dataset.py:592, in transmit_tasks.<locals>.wrapper(*args, **kwargs)
    590     self: "Dataset" = kwargs.pop("self")
    591 # apply actual function
--> 592 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    593 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    594 for dataset in datasets:
    595     # Remove task templates if a column mapping of the template is no longer valid

File /usr/local/lib/python3.10/site-packages/datasets/arrow_dataset.py:557, in transmit_format.<locals>.wrapper(*args, **kwargs)
    550 self_format = {
    551     "type": self._format_type,
    552     "format_kwargs": self._format_kwargs,
    553     "columns": self._format_columns,
    554     "output_all_columns": self._output_all_columns,
    555 }
    556 # apply actual function
--> 557 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    558 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    559 # re-apply format to the output

TypeError: Dataset.map() got an unexpected keyword argument 'dataset_num_proc'
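The traceback points at the cause: ORPOTrainer.__init__ forwards its new argument to Dataset.map under the trainer's own name, dataset_num_proc, while Dataset.map expects num_proc. A minimal sketch of corrected forwarding follows. StrictDataset is a hypothetical stand-in whose map() rejects unknown keywords like the real method, build_map_kwargs is a hypothetical helper rather than EasyDeL's actual patch, and letting an explicit num_proc in dataset_map_arguments take precedence is a design assumption.

```python
from typing import Any, Callable, Dict, List, Optional


class StrictDataset:
    """Hypothetical stand-in for datasets.Dataset: map() rejects unknown keywords."""

    def __init__(self, rows: List[dict]):
        self.rows = rows

    def map(self, function: Callable[[dict], dict], batched: bool = False,
            num_proc: Optional[int] = None) -> "StrictDataset":
        # Applied sequentially here; the real Dataset.map would use num_proc worker processes.
        return StrictDataset([function(row) for row in self.rows])


def build_map_kwargs(dataset_num_proc: Optional[int],
                     dataset_map_arguments: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Translate the trainer-level option into the keyword that map() expects."""
    kwargs = dict(dataset_map_arguments or {})
    if dataset_num_proc is not None:
        # An explicit num_proc in dataset_map_arguments wins (assumed precedence).
        kwargs.setdefault("num_proc", dataset_num_proc)
    return kwargs


def tokenize_row(row: dict) -> dict:
    # Placeholder for ORPOTrainer.tokenize_row: counts whitespace-separated words.
    return {**row, "n_words": len(row["prompt"].split())}


ds = StrictDataset([{"prompt": "a b c"}, {"prompt": "d e"}])

# Buggy call, as in the traceback: the trainer's parameter name leaks into map().
try:
    ds.map(tokenize_row, dataset_num_proc=4)
except TypeError as err:
    print("buggy:", err)

# Fixed call: forward the value under the name map() understands.
fixed = ds.map(tokenize_row, **build_map_kwargs(4, None))
print([r["n_words"] for r in fixed.rows])  # [3, 2]
```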