erfanzar / EasyDeL

Accelerate, Optimize performance with streamlined training and serving options with JAX.
https://easydel.readthedocs.io/en/latest/
Apache License 2.0
191 stars, 23 forks

How to reduce TPU RAM when finetuning? #131

Closed. IvoryTower800 closed this issue 5 months ago.

IvoryTower800 commented 6 months ago

Describe the bug

Hi, I really appreciate your continued commitment to this project and to making it better and better. I'm one of the people who benefit greatly from it. Thank you.

Now I am trying to fine-tune the Yi-34B-Chat model on Kaggle's TPU, but I keep running into out-of-memory errors.

In my experience, fine-tuning with transformers, QLoRA, and Flash Attention 2 in 16-bit precision on an A100 40 GB GPU uses about 33 GB of VRAM.

Although a TPU v3-8 VM offers 128 GB of TPU memory in total, I'm unable to complete the fine-tuning run because of memory constraints.

Previously, I fine-tuned the Yi-34B-Chat model extensively on an A100 40 GB, but I no longer have access to that machine. Now I need to continue fine-tuning on a TPU v3-8 with the same LoRA parameters and sequence length.

Is this possible? I would greatly appreciate some tips on reducing memory usage to fit within the TPU's constraints.

To Reproduce

from flax.core import FrozenDict
from EasyDel import (
    TrainArguments,
    CausalLanguageModelTrainer,
    AutoEasyDelModelForCausalLM,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers,
    EasyDeLXRapTureConfig
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer
import jax

huggingface_repo_id_or_path = "/kaggle/input/yi-34b-chat"

model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path)

model.config.rope_scaling = {
    "factor": 2.0,
    "type": "linear"
}

max_length = 8192
model.config.max_position_embeddings = max_length

model_parameters = FrozenDict({"params": params})

dtype = jnp.bfloat16
param_dtype = jnp.bfloat16

tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)

model.config.add_basic_configurations(
    attn_mechanism="normal"
)

tokenizer.pad_token = tokenizer.eos_token
configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": dtype,
    "param_dtype": param_dtype,
    "input_shape": (1, max_length)
}

rapture = EasyDeLXRapTureConfig(
    parameters=model_parameters,
    lora_dim=128,
#     fully_fine_tune_parameters=[],  # Model layer to be fully fine tuned
    lora_fine_tune_parameters=["q_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "v_proj", "k_proj"],
    verbose=True
)

train_arguments = TrainArguments(
    model_class=type(model),
    model_name="EasyDeL-Lora-Example",
    num_train_epochs=1,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    learning_rate=1e-4,
    learning_rate_end=8e-5,
    optimizer=EasyDelOptimizers.ADAMW,
    scheduler=EasyDelSchedulers.LINEAR,
    weight_decay=0.01,
    total_batch_size=1,
    do_train=True,
    do_eval=False,
    backend="tpu",
    max_length=max_length,
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, 1, 1, -1),
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    offload_device=jax.devices("cpu")[0],
    gradient_accumulation_steps=1,
    loss_re_mat="",
    dtype=dtype,
    param_dtype=param_dtype,
    rapture_config=rapture,
    track_memory=True,
    init_input_shape=(1, max_length),
    merge_lora_rapture_parameters=True
)

def ultra_chat_prompting_process(
        data_chunk
):
    user_part = [
        chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "user"
    ]
    assistant_part = [
        chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "assistant"
    ]

    prompt = ""

    for uc, ac in zip(user_part, assistant_part):
        prompt += f"<|user|>\n{uc}</s>\n<|assistant|>\n{ac}</s>\n"

    return {"prompt": prompt}

tokenization_process = lambda data_chunk: tokenizer(
    data_chunk["prompt"],
    add_special_tokens=False,
    max_length=max_length,
    truncation=True,  # cut prompts longer than max_length; padding alone only extends short ones
    padding="max_length"
)

dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
dataset_train = dataset["test_gen"].map(ultra_chat_prompting_process, num_proc=12)
dataset_train = dataset_train.map(
    tokenization_process,
    num_proc=12,
    remove_columns=dataset_train.column_names
)

# you can do the same for evaluation process dataset

trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset_train,
    checkpoint_path=None
)

output = trainer.train()
print(f"Hey ! , here's where your model saved {output.checkpoint_path}")

Action : Sharding Passed Parameters
Model Contain 34.880437248 Billion Parameters
  0%|          | 0/28304 [00:00<?, ?it/s]/usr/local/lib/python3.10/site-packages/fjformer/xrapture/implicit_array.py:461: UserWarning: Primitive scan was not handled by class LoraWeight, so implicit args will be materialized.
  warnings.warn(
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[3], line 131
    123 # you can do the same for evaluation process dataset
    125 trainer = CausalLanguageModelTrainer(
    126     train_arguments,
    127     dataset_train,
    128     checkpoint_path=None
    129 )
--> 131 output = trainer.train()  # you should not pass the parameters in Trainer.train anymore when
    132 # you are using LoRA or transfer Learning
    133 print(f"Hey ! , here's where your model saved {output.checkpoint_path}")

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py:581, in CausalLanguageModelTrainer.train(self, model_parameters, state)
    579 for ssb in self.arguments.ids_to_pop_from_dataset:
    580     _ = batch.pop(ssb, None)
--> 581 sharded_state, loss, accuracy = self.sharded_train_step_function(
    582     sharded_state,
    583     batch
    584 )
    585 loss_sum = loss.tolist() if loss_sum is None else loss_sum + loss
    586 accuracy_sum = accuracy.tolist() if accuracy_sum is None else accuracy_sum + accuracy

    [... skipping hidden 14 frame]

File /usr/local/lib/python3.10/site-packages/jax/_src/compiler.py:256, in backend_compile(backend, module, options, host_callbacks)
    251   return backend.compile(built_c, compile_options=options,
    252                          host_callbacks=host_callbacks)
    253 # Some backends don't have `host_callbacks` option yet
    254 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    255 # to take in `host_callbacks`
--> 256 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 20.25G of 15.48G hbm. Exceeded hbm capacity by 4.77G.

Total hbm usage >= 20.77G:
    reserved        530.00M 
    program          20.25G 
    arguments            0B 

Output size 0B; shares 0B with arguments.

Program hbm requirement 20.25G:
    global            1.31M
    scoped           50.83M
    HLO temp         20.20G (100.0% utilization: Unpadded (14.75G) Padded (14.75G), 27.0% fragmentation (5.45G))
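As a rough sanity check on this log (a back-of-the-envelope sketch, not anything EasyDeL prints), it helps to remember that a TPU v3-8 exposes 8 TensorCores to JAX, each with 16 GiB of HBM, so the limit that matters is per device, not the 128 GB total. The 15.48G limit plus the 530M reserved comes to roughly 16 GiB, which is why the sketch reads XLA's "G" as GiB. It assumes bf16 base weights (2 bytes per parameter) split perfectly evenly over all 8 devices, which is a best case; the helper names are hypothetical.

```python
# Back-of-the-envelope HBM check for the failed train step in the log above.
# ASSUMPTIONS (not taken from EasyDeL): bf16 base weights at 2 bytes/param,
# perfectly even sharding over all devices, and XLA's "G" meaning GiB (2**30 bytes).

NUM_PARAMS = 34_880_437_248   # "Model Contain 34.880437248 Billion Parameters"
BYTES_PER_PARAM = 2           # bfloat16
NUM_DEVICES = 8               # TPU v3-8: 8 TensorCores, each a separate JAX device
HBM_LIMIT_GIB = 15.48         # per-device capacity reported by XLA
PROGRAM_GIB = 20.25           # "Used 20.25G" for the compiled train step

GIB = 2**30


def weights_gib(num_params: int, bytes_per_param: int) -> float:
    """Total size of the base weights in GiB."""
    return num_params * bytes_per_param / GIB


def ideal_share_gib(total_gib: float, num_devices: int) -> float:
    """Per-device share if the weights were split perfectly evenly."""
    return total_gib / num_devices


total = weights_gib(NUM_PARAMS, BYTES_PER_PARAM)
share = ideal_share_gib(total, NUM_DEVICES)
overflow = PROGRAM_GIB - HBM_LIMIT_GIB

print(f"bf16 weights, whole model:      {total:.2f} GiB")
print(f"bf16 weights, ideal per device: {share:.2f} GiB of {HBM_LIMIT_GIB} GiB")
print(f"left for LoRA, optimizer state, activations: {HBM_LIMIT_GIB - share:.2f} GiB")
print(f"overflow reported by XLA:       {overflow:.2f} GiB")
```

Even in this best case the base weights take roughly half of each core's HBM before LoRA weights, AdamW state, and 8192-token activations are counted; any weights that end up replicated instead of sharded cost proportionally more.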
erfanzar commented 6 months ago

Hi, thanks for using EasyDeL! Sure, here are some tips:

  1. Use a lower max_length.
  2. Use flash attention.
  3. Use a lower lora_dim.
  4. Change lora_fine_tune_parameters from ["q_proj","o_proj","gate_proj","up_proj","down_proj","v_proj","k_proj"] to ["q_proj", "o_proj", "v_proj", "k_proj"].
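To see why these four tips help, here is a rough sizing sketch of the two quantities they touch. The first function estimates one un-sharded attention-score tensor of shape (batch, heads, q_len, kv_len), which normal attention materializes and flash attention avoids materializing in full; it grows with the square of max_length. The second counts LoRA parameters, where each adapted matrix of shape (d_in, d_out) adds r × (d_in + d_out) values, so the count grows linearly with lora_dim and with the set of target modules. The layer sizes (hidden 7168, MLP 20480, 60 layers, 56 query heads, 8 KV heads) are assumptions taken from the public Yi-34B config.json; check them against your local copy.

```python
# Rough sizing for the four tips above. Layer sizes are ASSUMED from the
# public Yi-34B config.json (verify against your local copy).
HIDDEN = 7168
INTERMEDIATE = 20480
NUM_LAYERS = 60
NUM_HEADS = 56
NUM_KV_HEADS = 8
HEAD_DIM = HIDDEN // NUM_HEADS  # 128

# (d_in, d_out) of each projection in one decoder layer
PROJ_SHAPES = {
    "q_proj": (HIDDEN, NUM_HEADS * HEAD_DIM),
    "k_proj": (HIDDEN, NUM_KV_HEADS * HEAD_DIM),
    "v_proj": (HIDDEN, NUM_KV_HEADS * HEAD_DIM),
    "o_proj": (NUM_HEADS * HEAD_DIM, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def attention_scores_gib(batch, seq_len, num_heads=NUM_HEADS, bytes_per_elem=2):
    """Size of one un-sharded bf16 (batch, heads, q_len, kv_len) score tensor."""
    return batch * num_heads * seq_len * seq_len * bytes_per_elem / 2**30


def lora_params(rank, targets, num_layers=NUM_LAYERS):
    """Trainable LoRA parameters: rank * (d_in + d_out) per adapted matrix."""
    per_layer = sum(rank * (PROJ_SHAPES[n][0] + PROJ_SHAPES[n][1]) for n in targets)
    return per_layer * num_layers


ALL = ["q_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "v_proj", "k_proj"]
ATTN_ONLY = ["q_proj", "o_proj", "v_proj", "k_proj"]

for seq in (8192, 4096, 2048):
    print(f"max_length={seq}: {attention_scores_gib(1, seq):.2f} GiB of scores per layer")

for rank, targets in ((128, ALL), (128, ATTN_ONLY), (64, ATTN_ONLY)):
    count = lora_params(rank, targets)
    print(f"lora_dim={rank}, {len(targets)} modules: {count / 1e6:.1f}M trainable params")
```

Halving max_length shrinks the score tensor by 4x, while halving lora_dim or dropping the MLP projections only reduces the (much smaller) LoRA and optimizer footprint linearly.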
IvoryTower800 commented 6 months ago

Thank you for your suggestions. I found that sequence sharding supports longer contexts, so I want to try sequence sharding with flash attention to maximize the sequence length.

I followed the instructions in the README, using the partition specs below, and got ValueError: Attention bias shape mismatch: expected (batch_size=1, num_heads=8, q_seq_len=1024, kv_seq_len=1024), got (1, 8, 8192, 8192).

# use these partition specs if you are not using custom sharding_axis_names and are using sequence sharding with flash attention
query_partition_spec=PartitionSpec(("dp", "fsdp"), None, "sp", "tp"),
generation_query_partition_spec=PartitionSpec(("dp", "fsdp"), None, None, "tp"),
key_partition_spec=PartitionSpec(("dp", "fsdp"), None, "sp", "tp"),
value_partition_spec=PartitionSpec(("dp", "fsdp"), None, "sp", "tp"),
attention_partition_spec=PartitionSpec(("dp", "fsdp"), None, "sp", "tp"),

So what is the correct partitioning strategy for using sequence sharding with flash attention?

erfanzar commented 6 months ago

Flash attention works with the FSDP version, which means you should use a batch size of at least 8.

IvoryTower800 commented 6 months ago

Yes, I understand. But can flash attention work with sequence sharding in EasyDeL?

I tried setting the batch size to 8 or more, but it still says ValueError: Attention bias shape mismatch: expected (batch_size=1, num_heads=8, q_seq_len=1024, kv_seq_len=1024), got (1, 8, 8192, 8192).

When I set bias_partition_spec=PartitionSpec(("dp", "fsdp"), None, "sp", None), it says ValueError: Attention bias shape mismatch: expected (batch_size=1, num_heads=8, q_seq_len=1024, kv_seq_len=1024), got (1, 8, 1024, 8192).
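The shapes in these two errors are consistent with all 8 cores of the v3-8 sitting on the "sp" axis, which is what sharding_array=(1, 1, 1, -1) would produce if the axes are ordered ("dp", "fsdp", "tp", "sp"). That ordering is an assumption inferred from the errors, not checked against EasyDeL's source, and the first error's bias spec is not shown, so the sketch assumes it was left unsharded. The hypothetical helper below mimics how a PartitionSpec-style tuple divides each array dimension by the product of the mesh axes mapped to it.

```python
from math import prod

# HYPOTHETICAL mesh for sharding_array=(1, 1, 1, -1) on a v3-8, ASSUMING the
# axis order ("dp", "fsdp", "tp", "sp"), so all 8 devices sit on "sp".
MESH = {"dp": 1, "fsdp": 1, "tp": 1, "sp": 8}


def local_shape(global_shape, spec, mesh=MESH):
    """Per-device block shape for a PartitionSpec-like tuple.

    Each spec entry is None (replicated), one mesh axis name, or a tuple of
    names; that dimension is divided by the product of the named axis sizes.
    Missing trailing entries are treated as None, as with jax PartitionSpec.
    """
    spec = tuple(spec) + (None,) * (len(global_shape) - len(spec))
    out = []
    for dim, axes in zip(global_shape, spec):
        if axes is None:
            out.append(dim)
            continue
        names = axes if isinstance(axes, tuple) else (axes,)
        shards = prod(mesh[name] for name in names)
        if dim % shards:
            raise ValueError(f"dimension {dim} is not divisible by {shards} shards")
        out.append(dim // shards)
    return tuple(out)


BATCH, HEADS, SEQ = 1, 8, 8192   # taken from the error messages
HEAD_DIM = 128                   # ASSUMED; does not affect the sequence dims

key = local_shape((BATCH, HEADS, SEQ, HEAD_DIM), (("dp", "fsdp"), None, "sp", "tp"))
bias_unsharded = local_shape((BATCH, HEADS, SEQ, SEQ), (("dp", "fsdp"), None, None, None))
bias_sp_on_q = local_shape((BATCH, HEADS, SEQ, SEQ), (("dp", "fsdp"), None, "sp", None))

print("local key block:      ", key)             # kv_seq_len seen per device
print("bias, not sharded:    ", bias_unsharded)  # compare with the first error
print("bias, 'sp' on q only: ", bias_sp_on_q)    # compare with the second error
```

In both cases each device holds 1024 key positions, which is the kv_seq_len the kernel expects, while the bias still spans all 8192. Sharding the sequence this way means each query block must also be combined with key blocks held on other devices, the cross-device exchange that ring-style attention algorithms are designed for.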

erfanzar commented 6 months ago

For that, you have to set the sharding array axes to 1,-1,1,1. I'll find a way, or re-create the algorithms, to make flash attention work with the sequence sharding method. (It's already possible, just not on Kaggle: you simply have to set tensor parallel and sequence parallel to the same number, like 1,1,4,4.)
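A small sketch of the arithmetic behind this advice, assuming that a -1 entry in sharding_array is filled in like a NumPy reshape wildcard and that the axes are ordered ("dp", "fsdp", "tp", "sp"); both are assumptions, not confirmed from EasyDeL's source, and the helper names are hypothetical. With 1,-1,1,1 on the 8 cores of a v3-8, every device lands on "fsdp", and because the batch dimension is sharded over ("dp", "fsdp") in the partition specs above, the batch must be a multiple of 8, matching the earlier batch-size advice. Note also that 1,1,4,4 multiplies out to 16 devices, twice the 8 that JAX sees on a v3-8.

```python
from math import prod

AXIS_NAMES = ("dp", "fsdp", "tp", "sp")  # ASSUMED axis order


def resolve_mesh(sharding_array, num_devices):
    """Fill in a single -1 entry so the axis sizes multiply to num_devices."""
    dims = list(sharding_array)
    if dims.count(-1) > 1:
        raise ValueError("at most one -1 is allowed")
    known = prod(d for d in dims if d != -1)
    if -1 in dims:
        if num_devices % known:
            raise ValueError(f"{num_devices} devices not divisible by {known}")
        dims[dims.index(-1)] = num_devices // known
    if prod(dims) != num_devices:
        raise ValueError(f"{tuple(dims)} needs {prod(dims)} devices, have {num_devices}")
    return dict(zip(AXIS_NAMES, dims))


def min_batch_size(mesh):
    """Batch is sharded over ("dp", "fsdp"), so it must be a multiple of dp * fsdp."""
    return mesh["dp"] * mesh["fsdp"]


V3_8_DEVICES = 8
for arr in ((1, 1, 1, -1), (1, -1, 1, 1), (1, 1, 4, 4)):
    try:
        mesh = resolve_mesh(arr, V3_8_DEVICES)
        print(arr, "->", mesh, "min batch:", min_batch_size(mesh))
    except ValueError as err:
        print(arr, "->", err)
```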

IvoryTower800 commented 6 months ago

Thank you for your kind explanation. Will it be possible on Kaggle in the near future?

erfanzar commented 6 months ago

Yes, it will be possible soon.

Also, the attention mechanism has been improved and is now faster and more efficient. You can try it again by changing the sharding array axes to 1,-1,1,1.

IvoryTower800 commented 6 months ago

@erfanzar Thank you! I'm looking forward to it. I found that training speed increased by about 25% for the Gemma model when using normal attention now. Amazing.