erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options with JAX. 🚀
https://easydel.readthedocs.io/en/latest/

NaN loss in ORPOTrainer with legacy_sharded_vanilla #156

Closed · nyl199310 closed 3 weeks ago

nyl199310 commented 1 month ago

Hi, I tried to use ORPOTrainer to fine-tune a model. I found that if I use sharded_vanilla or other attention mechanisms, I get a memory-resource-exhausted error even though the loss stats are normal (it runs only the first step before the out-of-memory error).

```
Training: 1%| | 1/100 [01:26<2:23:09, 86.76s/it, epoch=0, learning_rate=1.98e-5, log_odds_chosen=5451.875, log_odds_ratio=0.0, logits/chosen=-1.4420047, logits/rejected=-1.5062321, logps/chosen=-111799.91, logps/rejected=-117251.78, loss=2.25, mean_loss=2.25, nll_loss=2.2457747, perplexity=9.45, rewards/accuracies=1.0, rewards/chosen=-11179.991, rewards/margins=545.1875, rewards/rejected=-11725.179, step=1, step_time=81.1]
```

Only with legacy_sharded_vanilla is there no out-of-memory error, but then all the loss stats are NaN.

```
Training: 36%|███▌ | 36/100 [02:49<02:34, 2.41s/it, epoch=0, learning_rate=1.29e-5, log_odds_chosen=nan, log_odds_ratio=nan, logits/chosen=nan, logits/rejected=nan, logps/chosen=nan, logps/rejected=nan, loss=nan, mean_loss=nan, nll_loss=nan, perplexity=nan, rewards/accuracies=0.0, rewards/chosen=nan, rewards/margins=nan, rewards/rejected=nan, step=36, step_time=0.133]
```
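A generic JAX-level aid for localizing such NaNs (standard JAX debugging, not an EasyDeL feature) is to enable the NaN checker before training, so the first operation that produces a NaN raises instead of letting it propagate through the loss stats:

```python
# Standard JAX debugging flag (not EasyDeL-specific): make the first
# NaN-producing operation raise a FloatingPointError instead of propagating.
# It slows training noticeably, so enable it only while diagnosing.
import jax

jax.config.update("jax_debug_nans", True)
```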

erfanzar commented 1 month ago

Hello, and thanks for using EasyDeL.

Actually, I told you that legacy_sharded_vanilla has a lot of miscomputations and recommended not using it; that attention only works well on AMD GPUs, and I don't really know why.

But you can run the built-in attention self-test to see which mechanisms are numerically sound on your setup:

```python
from easydel import AttentionModule

print(AttentionModule.test_attentions(axis_dims=(1, 1, 1, -1)))
```
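The report lists, for each mechanism, the output and gradient deviation from a reference implementation, whether the test passed, and the compilation time, so you can pick a mechanism that passes on your exact setup (see the result tables later in this thread).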
erfanzar commented 1 month ago

https://easydel.readthedocs.io/en/latest/attentionmodule_example.html

nyl199310 commented 1 month ago

Hi, and thank you for your explanation. I may have given you the wrong idea. I first reduced max_length, then tested four attention mechanisms: sharded_vanilla, local_ring, wise_ring, and legacy_sharded_vanilla. All of them produce NaN loss stats (except on the first step). In addition, it's LoRA + ORPO. Below is my full code.

```python
from easydel import (
    AutoEasyDeLModelForCausalLM,
    EasyDeLXRapTureConfig,
    AutoEasyDeLConfig,
    EasyDeLState,
    TrainArguments,
    EasyDeLOptimizers,
    EasyDeLSchedulers,
    SFTTrainer,
    ORPOTrainer,
    EasyDeLGradientCheckPointers,
    easystate_to_huggingface_model,
    get_modules_by_type
)
from datasets import load_dataset
from transformers import AutoTokenizer, LlamaForCausalLM, AutoConfig
from jax import numpy as jnp, lax
from flax.core import FrozenDict
import jax
import flax
from huggingface_hub import HfApi

huggingface_model_repo_id = "NousResearch/Meta-Llama-3-8B-Instruct"
max_length = 2048

model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    huggingface_model_repo_id,
    device=jax.devices('cpu')[0],
    input_shape=(1, 2048),
    device_map="auto",
    sharding_axis_dims=(1, 1, 1, -1),
    config_kwargs=dict(
        use_scan_mlp=False,
        attn_mechanism='sharded_vanilla',
        max_length=2048
    ),
)

config = AutoEasyDeLConfig.from_pretrained(
    huggingface_model_repo_id
)

tokenizer = AutoTokenizer.from_pretrained(
    huggingface_model_repo_id,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token

configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, max_length)
}

params = FrozenDict({"params": params})
rapture = EasyDeLXRapTureConfig(
    parameters=params,
    lora_dim=128,
    fully_fine_tune_parameters=[],  # model layers to be fully fine-tuned; empty here
    lora_fine_tune_parameters=["q_proj", "v_proj", "k_proj", "o_proj"],  # LoRA target layers
    # (pass None for layer-only tuning or transfer learning)
    verbose=True
)

train_arguments = TrainArguments(
    model_class=get_modules_by_type(model.config.model_type)[1],
    model_name="llama3",
    num_train_epochs=1,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    learning_rate=2e-5,
#     step_start_point=step_start_point,
    learning_rate_end=2e-7,
    optimizer=EasyDeLOptimizers.ADAMW,
    scheduler=EasyDeLSchedulers.LINEAR,
    weight_decay=0.01,
    #dataloader_num_workers=96,
    total_batch_size=1,
    max_training_steps=None,
    do_train=True,
    do_eval=False,
    backend="tpu",
    max_sequence_length=max_length,
    gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, 1, 1, -1),
    init_input_shape=(1, max_length),
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=1,
    training_time="8H",
    track_memory=True,
    neftune_noise_alpha=5.0,
    force_batch_and_gradient_accumulation_steps_calculation=True,
    loss_re_mat="",
    dtype=jnp.bfloat16,
    use_wandb=False,
    rapture_config=rapture,
    merge_lora_rapture_parameters=True
)

train_dataset = load_dataset("Intel/orca_dpo_pairs")['train']
desired_indices = range(0, 100)
train_dataset = train_dataset.select(desired_indices)
train_dataset = train_dataset.rename_column('question', 'prompt')

trainer = ORPOTrainer(
    arguments=train_arguments,
    max_length=2048,
    max_prompt_length=2048,
    max_completion_length=2048,
    beta=0.1,
    train_dataset=train_dataset,
    eval_dataset=None,
    tokenizer=tokenizer,
    low_mem_usage=True,
    dataset_num_proc=1
)

output = trainer.train()
```

Besides, I don't know whether it's caused by the NaN loss issue, but after fine-tuning, the LoRA merging process reports the error below.

```
Training:  99%|█████████▉| 99/100 [04:06<00:01,  1.32s/it, epoch=0, learning_rate=3.98e-7, log_odds_chosen=nan, log_odds_ratio=nan, logits/chosen=nan, logits/rejected=nan, logps/chosen=nan, logps/rejected=nan, loss=nan, mean_loss=nan, nll_loss=nan, perplexity=nan, rewards/accuracies=0.0, rewards/chosen=nan, rewards/margins=nan, rewards/rejected=nan, step=99, step_time=0.129]
Info :  Merging LoRA Parameters.
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File /usr/local/lib/python3.10/site-packages/fjformer/xrapture/xrapture.py:376, in XRapTure.merge_parameters.<locals>._ensure_delete(val)
    375 try:
--> 376     val.device_buffer.delete()
    377 except ValueError:

File /usr/local/lib/python3.10/site-packages/jax/_src/array.py:484, in ArrayImpl.device_buffer(self)
    483   return self._arrays[0]
--> 484 raise ValueError('Length of buffers is greater than 1. Please use '
    485                  '`.device_buffers` instead.')

ValueError: Length of buffers is greater than 1. Please use `.device_buffers` instead.

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
Cell In[1], line 152
    136 train_dataset = train_dataset.rename_column('question', 'prompt')
    139 trainer = ORPOTrainer(
    140     arguments=train_arguments,
    141     max_length = 2048,
   (...)
    149     dataset_num_proc=1
    150 )
--> 152 output = trainer.train()
    153 # output = trainer.train(flax.core.FrozenDict({"params": params}))
    154 
    155 
   (...)
    174 # config.push_to_hub("ivt1993/writer_llama3_8b_test", private=True, token='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')
    175 # print('done')

File /usr/local/lib/python3.10/site-packages/easydel/trainer/orpo/orpo_trainer.py:1082, in ORPOTrainer.train(self, model_parameters, state)
   1072 if self.arguments.merge_lora_rapture_parameters and self.rapture is not None:
   1073     print(
   1074         termcolor.colored(
   1075             "Info : ", color="red", force_color=True
   (...)
   1079         )
   1080     )
   1081     self.model_state = self.model_state.replace(
-> 1082         params=self.rapture.merge_parameters(self.model_state.params)
   1083     )
   1085 shard_fns, gather_fns = make_shard_and_gather_fns(
   1086     partition_specs=match_partition_rules(
   1087         rules=self.model_state.module.config.get_partition_rules(
   (...)
   1092     dtype_specs=self.arguments.dtype
   1093 )
   1094 output = ORPOTrainerOutput(
   1095     state=self.model_state,
   1096     mesh=self.mesh,
   (...)
   1099     checkpoint_manager=self.checkpoint_manager,
   1100 )

File /usr/local/lib/python3.10/site-packages/fjformer/xrapture/xrapture.py:390, in XRapTure.merge_parameters(lora_parameters, destructive)
    387         return result
    388     return param
--> 390 return tree_map_with_implicit(map_fn, lora_parameters)

File /usr/local/lib/python3.10/site-packages/fjformer/xrapture/implicit_array.py:643, in combine_leaf_predicate.<locals>.new_fn(new_is_leaf, *args)
    641     def combined_is_leaf(arg):
    642         return is_leaf(arg) or new_is_leaf(arg)
--> 643 return base_fn(*args, is_leaf=combined_is_leaf)

File /usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py:244, in tree_map(f, tree, is_leaf, *rest)
    242 leaves, treedef = tree_flatten(tree, is_leaf)
    243 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 244 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File /usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py:244, in <genexpr>(.0)
    242 leaves, treedef = tree_flatten(tree, is_leaf)
    243 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 244 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File /usr/local/lib/python3.10/site-packages/fjformer/xrapture/xrapture.py:386, in XRapTure.merge_parameters.<locals>.map_fn(param)
    384     result = materialize(param)
    385     if destructive:
--> 386         jax.tree_map(_ensure_delete, param)
    387     return result
    388 return param

File /usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py:244, in tree_map(f, tree, is_leaf, *rest)
    242 leaves, treedef = tree_flatten(tree, is_leaf)
    243 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 244 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File /usr/local/lib/python3.10/site-packages/jax/_src/tree_util.py:244, in <genexpr>(.0)
    242 leaves, treedef = tree_flatten(tree, is_leaf)
    243 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 244 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File /usr/local/lib/python3.10/site-packages/fjformer/xrapture/xrapture.py:378, in XRapTure.merge_parameters.<locals>._ensure_delete(val)
    376     val.device_buffer.delete()
    377 except ValueError:
--> 378     val.device_buffers.delete()

AttributeError: 'list' object has no attribute 'delete'
```

erfanzar commented 1 month ago

Thanks for reporting the LoRA issue.
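From the traceback, `ArrayImpl.device_buffer` raises a `ValueError` for arrays sharded across multiple devices, and the fallback then calls `.delete()` on the `device_buffers` list itself rather than on its elements. A minimal sketch of what a fix could look like (a hypothetical patch for illustration, not the shipped fjformer code):

```python
# Hypothetical sketch of a fixed _ensure_delete for XRapTure.merge_parameters:
# delete each buffer of a sharded array instead of calling .delete() on the
# list returned by .device_buffers.
def _ensure_delete(val):
    try:
        val.device_buffer.delete()  # single-device arrays
    except ValueError:
        # sharded arrays expose a list of per-device buffers
        for buf in val.device_buffers:
            buf.delete()
```

That aside, let me explain the process in the AttentionModule.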

The attention implementations in the AttentionModule behave very differently in each scenario, on each device, and with each config.

For example, Flash Attention, Splash Attention, and Blockwise, which are among the best-supported modules, do not work well on TPUs. Here is how you can find the attention mechanism that works best for you.

> [!TIP]
> The following results are from a Kaggle TPU v3 with JAX 0.4.28; they will differ on your device and with different JAX versions.

```python
ed.AttentionModule.test_attentions(axis_dims=(1, -1, 1, 1))  # FSDP attention
```

| METHOD | OUT DIFF | GRADIENT DIFF | TEST PASSED | COMP TIME |
|---|---|---|---|---|
| LOCAL_RING | 4.5249023 | 0.11016822 | False | 6.365082 |
| BLOCKWISE | 2.142334 | 0.1539554 | False | 1.802229 |
| VANILLA | 0.0028076172 | 0.0055647814 | True | 0.019311 |
| WISE_RING | 1917.0201 | 31.526917 | False | 5.255283 |
| SHARDED_VANILLA | 0.0014648438 | 0.0 | True | 0.040799 |
| LEGACY_SHARDED_VANILLA | 0.0014648438 | 0.0 | True | 8.804201 |
| FLASH | 4.2730713 | 0.11256987 | False | 2.643749 |
| SPLASH | 8935.969 | 470.02725 | False | 4.505385 |
| CUDNN | NA | NA | NA | NA |
| PALLAS_FLASH | NA | NA | NA | NA |

As you can see, in this case legacy sharded vanilla, vanilla, and sharded vanilla work fine.

But let's change the axis dims to (1, 1, 1, -1), i.e., the sequence-sharding method:

```python
ed.AttentionModule.test_attentions(axis_dims=(1, 1, 1, -1))  # sequence-sharding attention
```

| METHOD | OUT DIFF | GRADIENT DIFF | TEST PASSED | COMP TIME |
|---|---|---|---|---|
| LOCAL_RING | nan | nan | False | 5.171325 |
| BLOCKWISE | 2.1427002 | 0.1539554 | False | 4.047265 |
| VANILLA | 0.0005493164 | 0.0 | True | 1.270251 |
| WISE_RING | nan | nan | False | 5.366389 |
| SHARDED_VANILLA | 0.0005493164 | 0.0 | True | 3.149804 |
| LEGACY_SHARDED_VANILLA | nan | nan | False | 8.098974 |
| FLASH | NA | NA | NA | NA |
| SPLASH | 8935.971 | 470.02725 | False | 5.481151 |
| CUDNN | NA | NA | NA | NA |
| PALLAS_FLASH | NA | NA | NA | NA |

As you can see, some attention mechanisms produce NaN outputs under the sequence-sharding method, and only vanilla and sharded vanilla work here; but if you are using other TPU versions or GPUs, all of the attention mechanisms may work for you.
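In practice that means running the self-test for your sharding layout and then passing a mechanism that passed it through `config_kwargs` when loading the model. A sketch using the same `from_pretrained` call as the script above (`vanilla` is chosen only because it passed the sequence-sharding test in this run):

```python
# Sketch: load the model with an attention mechanism that passed
# test_attentions for the sequence-sharding layout (1, 1, 1, -1).
from easydel import AutoEasyDeLModelForCausalLM

model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    "NousResearch/Meta-Llama-3-8B-Instruct",
    sharding_axis_dims=(1, 1, 1, -1),              # sequence sharding, as tested above
    config_kwargs=dict(attn_mechanism="vanilla"),  # passed the self-test for this layout
)
```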

erfanzar commented 1 month ago

@nyl199310 hello, is the issue fixed?

nyl199310 commented 1 month ago

@erfanzar Sorry, I was not able to test it; I cannot start training with the latest code. With JAX 0.4.25 it always stops after displaying the output below.

```
/usr/local/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
E0601 03:43:44.061192666    2836 oauth2_credentials.cc:238]            oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {created_time:"2024-06-01T03:43:44.0611721+00:00", grpc_status:2}
/usr/local/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
/usr/local/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Converting Model: 100%|██████████| 164/164 [01:22<00:00,  1.99it/s, missed_shardings=0]
/usr/local/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
/usr/local/lib/python3.10/site-packages/easydel/trainer/training_configurations.py:388: UserWarning: setting `log_grad_norms` to off since using log grad norms while using LoRA is not Supported.
  warnings.warn(
Warning :  You are using LoRA (Low-Rank Adaptation of Large Language Models) and this feature isstill in Beta mode so it might act unexpected
Downloading readme: 100%|██████████| 196/196 [00:00<00:00, 1.14MB/s]
Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]
Downloading data:   0%|          | 0.00/36.3M [00:00<?, ?B/s]
Downloading data:  29%|██▉       | 10.5M/36.3M [00:00<00:01, 13.8MB/s]
Downloading data:  87%|████████▋ | 31.5M/36.3M [00:01<00:00, 32.1MB/s]
Downloading data: 100%|██████████| 36.3M/36.3M [00:01<00:00, 28.9MB/s]
Downloading data files: 100%|██████████| 1/1 [00:01<00:00,  1.34s/it]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 959.79it/s]
Generating train split: 12859 examples [00:00, 95186.13 examples/s]
Map: 100%|██████████| 100/100 [01:18<00:00,  1.27 examples/s]
```

With JAX 0.4.28 it only displays the output below, then stops before running the remaining code.

```
/usr/local/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
/usr/local/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
```

I'm still trying to figure out what's wrong with my code; it used to run normally.

erfanzar commented 1 month ago

It's due to Kaggle environment changes; they must have changed a lot of things in their environment. I'll fix this ASAP.

erfanzar commented 1 month ago

This should fix it:

```sh
pip install -r https://raw.githubusercontent.com/erfanzar/EasyDeL/main/env_requirements.txt
```

erfanzar commented 4 weeks ago

This should fix it:

```sh
pip install -r https://raw.githubusercontent.com/erfanzar/EasyDeL/main/env_requirements.txt
```

Update: you no longer need to do this; it's fixed in the new EasyDeL version 0.0.67.