erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options with JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

Training with Ring Attention Failed #120

Closed IvoryTower800 closed 3 months ago

IvoryTower800 commented 3 months ago

Describe the bug
Hi, I ran the code below on Kaggle's TPU VM v3-8. When I set `attn_mechanism` to "normal", it worked well. However, when I changed `attn_mechanism` to "ring", the error below was raised. Could you please guide me on how to fix it?

To Reproduce

```python
!pip install datasets
!pip install git+https://github.com/erfanzar/EasyDeL.git
!pip install jax[tpu]==0.4.23 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install tensorflow -U
```

```python
import os

# disable Weights and Biases
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    AutoTokenizer,
    GenerationConfig
)
from tqdm import tqdm
import time
import pandas as pd
import numpy as np
from functools import partial
from transformers import set_seed
from ast import literal_eval

from huggingface_hub import interpreter_login

interpreter_login()

IS_TPU = True

import jax.numpy
from EasyDel import (
    TrainArguments,
    CausalLanguageModelTrainer,
    AutoEasyDelModelForCausalLM,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer

# disable Weights and Biases
os.environ['WANDB_DISABLED'] = "true"
huggingface_repo_id_or_path = "microsoft/phi-2"

model, params = AutoEasyDelModelForCausalLM.from_pretrained(
    huggingface_repo_id_or_path,
    sharding_axis_dims=(1, 1, 1, -1)
)
model.config.add_basic_configurations(
    attn_mechanism="ring",  # Using Flash Attention here you can simply just set this to normal or ring
    block_b=1,
    block_q=128,
    block_k=128,
    block_k_major=128,
)
max_length = 2048
tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (8, 1024)
}
```

```python
train_arguments = TrainArguments(
    model_class=type(model),
    model_name="my_first_model_to_train_using_easydel",
    num_train_epochs=3,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    max_sequence_length=max_length,
    optimizer=EasyDelOptimizers.ADAMW,  # "adamw", "lion", "adafactor" are supported
    scheduler=EasyDelSchedulers.LINEAR,
    # "linear", "cosine", "none", "warm_up_cosine" and "warm_up_linear" are supported
    weight_decay=0.01,
    total_batch_size=8,
    max_training_steps=None,  # None to let the trainer decide
    do_train=True,
    do_eval=False,  # it's optional but supported
    backend="tpu",  # the default backend is cpu, so you must specify tpu, cpu or gpu
    max_length=max_length,  # note that you have to change this in the model config too
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, 1, 1, -1),  # how to shard the model across CPUs, GPUs or TPUs
    # training will be sequence- and model-parallel automatically, sharing data between devices
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=1,
    loss_re_mat="",
    dtype=jnp.bfloat16,
    use_wandb=False  # this disables WANDB usage
)
```

```python
def ultra_chat_prompting_process(data_chunk):
    user_part = [
        chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "user"
    ]
    assistant_part = [
        chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "assistant"
    ]

    prompt = ""

    for uc, ac in zip(user_part, assistant_part):
        prompt += f"<|user|>\n{uc}</s>\n<|assistant|>\n{ac}</s>\n"

    return {"prompt": prompt}


tokenization_process = lambda data_chunk: tokenizer(
    data_chunk["prompt"],
    add_special_tokens=False,
    max_length=max_length,
    padding="max_length"
)

dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
dataset_train = dataset["train_gen"].map(ultra_chat_prompting_process, num_proc=os.cpu_count())
dataset_train = dataset_train.map(
    tokenization_process,
    num_proc=os.cpu_count(),
    remove_columns=dataset_train.column_names
)

# you can do the same for the evaluation process dataset

trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset_train,
    checkpoint_path=None
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey ! , here's where your model saved {output.checkpoint_path}")
```

/usr/local/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm E0307 14:46:35.492612126 5786 oauth2_credentials.cc:238] oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {created_time:"2024-03-07T14:46:35.492596286+00:00", grpc_status:2} /usr/local/lib/python3.10/site-packages/pydantic/_internal/_fields.py:149: UserWarning: Field "modelname" has conflict with protected namespace "model".

You may be able to resolve this warning by setting model_config['protected_namespaces'] = (). warnings.warn( 2024-03-07 14:46:43.224169: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2024-03-07 14:46:43.224227: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2024-03-07 14:46:43.226292: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00, 1.02it/s] Converting Model: 100%|██████████| 453/453 [00:25<00:00, 18.11it/s] Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. Information : track_memory is set to False by default inorder make make training faster. you can turn it on with just passing track_memory=True in TrainArguments /usr/local/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by promote_options='default'. table = cls._concat_blocks(blocks, axis=0) Warning : In case of using finetune = True and Passing checkpoint_path = None you should pass parameters in train function Time Took to Complete Task configure dataloaders (microseconds) : 0.31828880310058594

AssertionError Traceback (most recent call last) Cell In[1], line 129 121 dataset_train = dataset_train.map( 122 tokenization_process, 123 num_proc=os.cpu_count(), 124 remove_columns=dataset_train.column_names 125 ) 127 # you can do the same for evaluation process dataset --> 129 trainer = CausalLanguageModelTrainer( 130 train_arguments, 131 dataset_train, 132 checkpoint_path=None 133 ) 135 output = trainer.train(flax.core.FrozenDict({"params": params})) 136 print(f"Hey ! , here's where your model saved {output.checkpoint_path}")

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/base_trainer.py:144, in BaseTrainer.init(self, arguments, dataset_train, dataset_eval, finetune, checkpoint_path, _do_init_fns) 138 prefix_print( 139 "Warning", 140 "In case of using finetune = True and Passing checkpoint_path = None" 141 " you should pass parameters in train function" 142 ) 143 if _do_init_fns: --> 144 self.initialize_trainer_utils() 145 else: 146 prefix_print( 147 "Warning", 148 "you have set _do_init_fns = False so function will not me initialized you have " 149 f"to do in manually (simply with trainer.initialize_trainer_utils() )" 150 )

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/base_trainer.py:203, in BaseTrainer.initialize_trainer_utils(self) 200 self.timer.log(["configure dataloaders"]) 202 self.timer("configure Model, Optimizer, Scheduler and Config").start() --> 203 model_configurations = self.configure_model() 204 model = model_configurations.model 205 tx = model_configurations.tx

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/base_trainer.py:325, in BaseTrainer.configure_model(self) 319 self.arguments.configs_to_initialize_model_class[ 320 "config" 321 ].use_pjit_attention_force = self.arguments.use_pjit_attention_force 323 self.arguments.configs_to_initialize_model_class["config"].axis_dims = self.arguments.sharding_array --> 325 model = self.arguments.model_class( 326 **self.arguments.configs_to_initialize_model_class, 327 _do_init=False 328 ) 330 config = self.arguments.configs_to_initialize_model_class["config"] 332 else:

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/phi/modelling_phi_flax.py:755, in FlaxPhiPreTrainedModel.init(self, config, dtype, param_dtype, precision, input_shape, seed, _do_init) 739 def init( 740 self, 741 config: PhiConfig, (...) 747 _do_init: bool = False 748 ) -> None: 749 module = self.module_class( 750 config=config, 751 dtype=dtype, 752 param_dtype=param_dtype, 753 precision=precision 754 ) --> 755 super().init( 756 config=config, 757 module=module, 758 input_shape=input_shape, 759 _do_init=_do_init, 760 seed=seed 761 )

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/easydel_modelling_utils.py:354, in EasyDelFlaxPretrainedModel.init(self, config, module, input_shape, seed, dtype, param_dtype, precision, _do_init) 343 def init( 344 self, 345 config: PretrainedConfig, (...) 352 _do_init: bool = True, 353 ): --> 354 super().init( 355 config=config, 356 module=module, 357 input_shape=input_shape, 358 seed=seed, 359 dtype=dtype, 360 _do_init=_do_init 361 )

File /usr/local/lib/python3.10/site-packages/transformers/modeling_flax_utils.py:223, in FlaxPreTrainedModel.init(self, config, module, input_shape, seed, dtype, _do_init) 221 else: 222 init_fn = partial(self.init_weights, input_shape=input_shape) --> 223 params_shape_tree = jax.eval_shape(init_fn, self.key) 225 logger.info( 226 "Model weights are not initialized as _do_init is set to False. " 227 f"Make sure to call {self.__class__.__name__}.init_weights manually to initialize the weights." 228 ) 230 # get the shape of the parameters

[... skipping hidden 13 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/phi/modelling_phi_flax.py:781, in FlaxPhiPreTrainedModel.init_weights(self, rng, input_shape, params) 778 params_rng, dropout_rng = jax.random.split(rng) 779 rngs = {"params": params_rng, "dropout": dropout_rng} --> 781 module_init_outputs = self.module.init(rngs, input_ids, attention_mask) 783 random_params = module_init_outputs["params"] 785 if params is not None:

[... skipping hidden 9 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/phi/modelling_phi_flax.py:706, in FlaxPhiForCausalLMModule.call(self, input_ids, inputs_embeds, attention_mask, position_ids, extra_embedding, deterministic, output_attentions, output_hidden_states, init_cache, return_dict) 693 def call( 694 self, 695 input_ids: Optional[chex.Array] = None, (...) 704 return_dict: bool = True, 705 ) -> tuple[Any, ...] | FlaxMaskedLMOutput: --> 706 res = self.model( 707 input_ids=input_ids, 708 attention_mask=attention_mask, 709 init_cache=init_cache, 710 deterministic=deterministic, 711 extra_embedding=extra_embedding, 712 position_ids=position_ids, 713 output_attentions=output_attentions, 714 output_hidden_states=output_hidden_states, 715 return_dict=True 716 ) 717 outputs = (res.last_hidden_state, res.hidden_states, res.attentions) 718 if self.config.tie_word_embeddings:

[... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/phi/modelling_phi_flax.py:638, in FlaxPhiModule.call(self, input_ids, inputs_embeds, attention_mask, position_ids, extra_embedding, deterministic, output_attentions, output_hidden_states, init_cache, return_dict) 635 assert sequence_length <= self.config.max_position_embeddings, "Maximum Position Embedding Reached !" 636 inputs_embeds = inputs_embeds + extra_embedding if extra_embedding is not None else inputs_embeds --> 638 outputs = self.layers( 639 hidden_states=inputs_embeds, 640 freq_cis=self.freq_cis, 641 attention_mask=attention_mask, 642 position_ids=position_ids, 643 causal_mask=self.causal_mask, 644 deterministic=deterministic, 645 init_cache=init_cache, 646 output_attentions=output_attentions, 647 output_hidden_states=output_hidden_states, 648 return_dict=return_dict, 649 ) 651 hidden_states = outputs[0] 652 hidden_states = self.final_layernorm(hidden_states)

[... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/phi/modelling_phi_flax.py:528, in FlaxPhiDecoderLayerCollection.call(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, deterministic, output_attentions, output_hidden_states, init_cache, return_dict) 518 all_hidden_states += (hidden_states,) 520 # hidden_states: chex.Array, 521 # freq_cis: Tuple[chex.Array, chex.Array], 522 # attention_mask: Optional[chex.Array], (...) 526 # output_attentions: bool = False, 527 # init_cache: bool = False, --> 528 layer_outputs = decoder_layer( 529 hidden_states, 530 freq_cis, 531 attention_mask, 532 position_ids, 533 causal_mask, 534 deterministic, 535 output_attentions, 536 init_cache, 537 ) 539 hidden_states = layer_outputs[0] 541 if output_attentions:

[... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py:553, in core_remat_static..inner(scope_fn, repack_fn, variable_groups, rng_groups, args) 550 y = fn(scope, args) 551 return y, repack_fn(scope) --> 553 return rematted(variable_groups, rng_groups, *dyn_args)

[... skipping hidden 7 frame]

File /usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py:550, in core_remat_static..inner..rematted(variable_groups, rng_groups, dyn_args) 548 args = _repack_remat_args(dyn_args, static_args) 549 scope = scope_fn(variable_groups, rng_groups) --> 550 y = fn(scope, args) 551 return y, repack_fn(scope)

[... skipping hidden 3 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/phi/modelling_phi_flax.py:425, in FlaxPhiDecoderLayer.call(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, deterministic, output_attentions, init_cache) 422 residual = hidden_states 423 hidden_states = self.input_layernorm(hidden_states) --> 425 attn_out = self.self_attn( 426 hidden_states=hidden_states, 427 attention_mask=attention_mask, 428 position_ids=position_ids, 429 output_attentions=output_attentions, 430 deterministic=deterministic, 431 freq_cis=freq_cis, 432 causal_mask=causal_mask, 433 init_cache=init_cache, 434 ) 435 attn_outputs, self_attn_weights = (attn_out[0], attn_out[1]) if len(attn_out) == 2 else (attn_out[0], None) 437 attn_outputs = self.resid_dropout(attn_outputs, deterministic=deterministic)

[... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/phi/modelling_phi_flax.py:361, in FlaxPhiAttention.call(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, deterministic, output_attentions, init_cache) 352 attention_bias = lax.select( 353 attention_mask > 0, 354 jnp.full(attention_mask.shape, 0.0).astype(self.dtype), 355 jnp.full(attention_mask.shape, jnp.finfo( 356 self.dtype).min).astype(self.dtype), 357 ) 359 query_length, key_length = query_states.shape[1], key_states.shape[1] --> 361 attentions = self.attention_performer.call( 362 query_states=query_states, 363 key_states=key_states, 364 value_states=value_states, 365 bias=attention_bias, 366 causal=True, 367 use_pjit_attention_force=self.config.use_pjit_attention_force, 368 dropout_rng=dropout_rng, 369 deterministic=deterministic, 370 query_sequence_length=query_length, 371 key_value_sequence_length=key_length, 372 uses_cache=self.has_variable("cache", "cached_key") or init_cache, 373 ) 374 attentions.attention_outputs = attentions.attention_outputs 375 attn_output = self._merge_heads(attentions.attention_outputs)

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/easy_attention.py:219, in EasyAttention.call(self, query_states, key_states, value_states, query_sequence_length, key_value_sequence_length, bias, segment_ids, causal, deterministic, dropout_rng, use_pjit_attention_force, uses_cache) 206 attentions = self._qkv_normal_op( 207 query_states=query_states, 208 key_states=key_states, (...) 216 key_value_sequence_length=key_value_sequence_length, 217 ) 218 elif self.attn_mechanism == "ring": --> 219 attentions = self._qkv_ring_op( 220 query_states=query_states, 221 key_states=key_states, 222 value_states=value_states, 223 bias=bias, 224 dropout_rng=dropout_rng, 225 use_pjit_attention_force=use_pjit_attention_force, 226 causal=causal, 227 deterministic=deterministic, 228 query_sequence_length=query_sequence_length, 229 key_value_sequence_length=key_value_sequence_length, 230 segment_ids=segment_ids, 231 ) 233 elif self.attn_mechanism == "splash": 234 raise NotImplementedError("Splash Attention is not Implemented YET!")

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/easy_attention.py:296, in EasyAttention._qkv_ring_op(self, query_states, key_states, value_states, query_sequence_length, key_value_sequence_length, bias, causal, deterministic, dropout_rng, use_pjit_attention_force, segment_ids) 266 ring_attention_fn = ring_attention 267 ring_attention_sharded = shard_map( 268 partial( 269 ring_attention_fn, (...) 294 check_rep=False 295 ) --> 296 attn_output = ring_attention_sharded(query_states, key_states, value_states, bias, segment_ids) 297 attn_output = with_sharding_constraint(attn_output, self.attention_partition_spec) 298 else:

[... skipping hidden 12 frame]

File /usr/local/lib/python3.10/site-packages/fjformer/pallas_operations/ring_attention/ring_attention.py:561, in ring_flash_attention_tpu(q, k, v, attn_bias, segment_ids, axis_name, float32_logits, blockwise_kwargs) 559 @partial(jax.custom_vjp, nondiff_argnums=[5, 6, 7]) 560 def ring_flash_attention_tpu(q, k, v, attn_bias, segment_ids, axis_name, float32_logits, blockwise_kwargs): --> 561 y, _ = _ring_flash_attention_fwd_tpu(q, k, v, attn_bias, segment_ids, axis_name, float32_logits, blockwise_kwargs) 562 return y

File /usr/local/lib/python3.10/site-packages/fjformer/pallas_operations/ring_attention/ring_attention.py:474, in _ring_flash_attention_fwd_tpu(q, k, v, attn_bias, segment_ids, axis_name, float32_logits, blockwise_kwargs) 470 k, v = map(lambda x: lax.ppermute(x, axis_name, perm=[(i, (i + 1) % axis_size) for i in range(axis_size)]), 471 (k, v)) 472 return (o, l, m, k, v), None --> 474 (o, l, m, _, _), _ = lax.scan(scan_kv_block, 475 init=(o, l, m, k, v), xs=jnp.arange(0, axis_size)) 476 output = rearrange(o.astype(v.dtype), 'b h q d -> b q h d') 477 return output, (o, q, k, v, attn_bias, segment_ids, l, m)

[... skipping hidden 9 frame]

File /usr/local/lib/python3.10/site-packages/fjformer/pallas_operations/ring_attention/ring_attention.py:457, in _ring_flash_attention_fwd_tpu..scan_kv_block(carry, idx) 455 q_chunk_idx_start = q_block_idx * (q_block_size // query_chunk_size) 456 k_chunk_idx_start = k_block_idx * (kv_block_size // key_chunk_size) --> 457 o, l, m = _flash_attention_fwd( 458 q, k, v, 459 carry=(o, l, m), 460 q_chunk_idx_start=q_chunk_idx_start, 461 k_chunk_idx_start=k_chunk_idx_start, 462 ab=attn_bias_slice, 463 segment_ids=segment_ids_slice, 464 save_residuals=False, 465 causal=blockwise_kwargs["causal"], 466 sm_scale=scale, 467 block_sizes=block_sizes, 468 debug=False 469 ) 470 k, v = map(lambda x: lax.ppermute(x, axis_name, perm=[(i, (i + 1) % axis_size) for i in range(axis_size)]), 471 (k, v)) 472 return (o, l, m, k, v), None

File /usr/local/lib/python3.10/site-packages/fjformer/pallas_operations/ring_attention/ring_attention.py:714, in _flash_attention_fwd(q, k, v, carry, q_chunk_idx_start, k_chunk_idx_start, ab, segment_ids, save_residuals, causal, sm_scale, block_sizes, debug) 712 if save_residuals: 713 raise NotImplementedError("Higher-order AD not supported") --> 714 o, l, m = _flash_attention( 715 q, 716 k, 717 v, 718 carry, 719 q_chunk_idx_start, 720 k_chunk_idx_start, 721 ab, 722 segment_ids, 723 True, 724 causal, 725 sm_scale, 726 block_sizes, 727 debug, 728 ) 729 return o, l, m

File /usr/local/lib/python3.10/site-packages/fjformer/pallas_operations/ring_attention/ring_attention.py:677, in _flash_attention(q, k, v, carry, q_chunk_idx_start, k_chunk_idx_start, ab, segment_ids, save_residuals, causal, sm_scale, block_sizes, debug) 662 def _flash_attention( 663 q, 664 k, (...) 675 debug, 676 ): --> 677 return _flash_attention_impl( 678 q, 679 k, 680 v, 681 carry, 682 q_chunk_idx_start, 683 k_chunk_idx_start, 684 ab, 685 segment_ids, 686 save_residuals, 687 causal, 688 sm_scale, 689 block_sizes.block_b, 690 block_sizes.block_q, 691 block_sizes.block_k_major, 692 block_sizes.block_k, 693 debug, 694 )

File /usr/local/lib/python3.10/site-packages/fjformer/pallas_operations/ring_attention/ring_attention.py:1102, in _flash_attention_impl(q, k, v, carry, q_chunk_idx_start, k_chunk_idx_start, ab, segment_ids, save_residuals, causal, sm_scale, block_b, block_q, block_k_major, block_k, debug) 1096 out_specs += [ 1097 pl.BlockSpec(lambda *_: (0, 0, 0, 0), m_scratch.shape), 1098 pl.BlockSpec(lambda *_: (0, 0, 0, 0), l_scratch.shape), 1099 pl.BlockSpec(lambda *_: (0, 0, 0, 0), acc_scratch.shape), 1100 ] 1101 else: -> 1102 assert False 1103 out_shape += [None, None, None] 1104 out_specs += [None, None, None]

AssertionError:

erfanzar commented 3 months ago

Hello, and thanks for using EasyDeL. I can't quite read your error because of all the broken markdown formatting, but it's coming from your partition specs and partitioning setup.

Use the sharding array (1, -1, 1, 1).
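
For reference, a rough sketch of what that tuple maps onto, assuming EasyDeL uses the axis names ('dp', 'fsdp', 'tp', 'sp') that show up in the shard_map error later in this thread:

```python
# Sketch: resolve an EasyDeL-style axis-dims tuple into a JAX device mesh.
# With (1, -1, 1, 1) on a TPU v3-8, the -1 absorbs all 8 devices into the fsdp axis.
import jax
import numpy as np
from jax.sharding import Mesh

axis_dims = (1, -1, 1, 1)                        # dp, fsdp, tp, sp
dims = np.array(axis_dims)
known = int(np.prod(dims[dims != -1]))
dims[dims == -1] = len(jax.devices()) // known   # -1 takes all remaining devices

mesh = Mesh(np.array(jax.devices()).reshape(dims), ("dp", "fsdp", "tp", "sp"))
print(mesh.shape)  # {'dp': 1, 'fsdp': 8, 'tp': 1, 'sp': 1} on a v3-8
```

In the reproduction code above, the same tuple goes to `sharding_axis_dims` in `from_pretrained` and `sharding_array` in `TrainArguments`.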

IvoryTower800 commented 3 months ago

@erfanzar Thanks for your reply. I tried the sharding array (1, -1, 1, 1), but it raised the error below. I'm really confused about the input shape and sharding strategy in EasyDeL: some of your examples use (1, -1, 1, 1), some use (1, 1, 4, -1), and some use (1, 1, 1, -1). If it's convenient for you, could you please briefly explain how to choose these in the project's tutorial, or provide a few generic examples? That would be very helpful.

Warning : In case of using finetune = True and Passing checkpoint_path = None you should pass parameters in train function Time Took to Complete Task configure dataloaders (microseconds) : 0.5710124969482422 Time Took to Complete Task configure Model, Optimizer, Scheduler and Config (microseconds) : 5663.623332977295 /usr/local/lib/python3.10/site-packages/EasyDel/modules/easy_attention.py:300: UserWarning: Using Ring attention on CPUs or GPUs are not recommended due to miss computations at the moment. please refer to other types of attention mechanism.your are bing fell back on ring_attention_sharded Usage conditions was scan_ring_attention = True [MUST BE TRUE] query_states.shape1 > max(128,128)(128) warnings.warn(

ValueError Traceback (most recent call last) Cell In[3], line 129 121 dataset_train = dataset_train.map( 122 tokenization_process, 123 num_proc=os.cpu_count(), 124 remove_columns=dataset_train.column_names 125 ) 127 # you can do the same for evaluation process dataset --> 129 trainer = CausalLanguageModelTrainer( 130 train_arguments, 131 dataset_train, 132 checkpoint_path=None 133 ) 135 output = trainer.train(flax.core.FrozenDict({"params": params})) 136 print(f"Hey ! , here's where your model saved {output.checkpoint_path}")

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/base_trainer.py:144, in BaseTrainer.init(self, arguments, dataset_train, dataset_eval, finetune, checkpoint_path, _do_init_fns) 138 prefix_print( 139 "Warning", 140 "In case of using finetune = True and Passing checkpoint_path = None" 141 " you should pass parameters in train function" 142 ) 143 if _do_init_fns: --> 144 self.initialize_trainer_utils() 145 else: 146 prefix_print( 147 "Warning", 148 "you have set _do_init_fns = False so function will not me initialized you have " 149 f"to do in manually (simply with trainer.initialize_trainer_utils() )" 150 )

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/base_trainer.py:227, in BaseTrainer.initialize_trainer_utils(self) 225 self.timer.log(["configure Model, Optimizer, Scheduler and Config"]) 226 self.timer("configure functions and sharding them").start() --> 227 function_configurations = self.configure_functions() 228 self.create_sharded_state_from_params_function = \ 229 function_configurations.create_sharded_state_from_params_function 230 self.sharded_train_step_function = function_configurations.sharded_train_step_function

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py:256, in CausalLanguageModelTrainer.configure_functions(self) 242 else: 243 return EasyDelState( 244 step=0, 245 apply_fn=self.lora_apply_fn, (...) 253 module_config_args=None, 254 ) --> 256 state_shape = jax.eval_shape(initialize_state_function) 257 state_partition_spec = match_partition_rules( 258 self.config.get_partition_rules( 259 fully_sharded_data_parallel=self.arguments.fully_sharded_data_parallel 260 ) if self.arguments.custom_rule is None else self.arguments.custom_rule, 261 state_shape 262 ) 263 create_sharded_state_from_params_function = pjit( 264 create_state_from_params_function, 265 in_shardings=(state_partition_spec.params,), 266 out_shardings=state_partition_spec, 267 donate_argnums=(0,) 268 )

[... skipping hidden 13 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py:185, in CausalLanguageModelTrainer.configure_functions..initialize_state_function() 184 def initialize_state_function(): --> 185 initialized_parameters = self.model.init_weights( 186 jax.random.PRNGKey(0), 187 self.arguments.init_input_shape 188 ) 190 if self.arguments.dtype == jnp.bfloat16: 191 initialized_parameters = self.model.to_bf16(initialized_parameters)

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/phi/modelling_phi_flax.py:781, in FlaxPhiPreTrainedModel.init_weights(self, rng, input_shape, params) 778 params_rng, dropout_rng = jax.random.split(rng) 779 rngs = {"params": params_rng, "dropout": dropout_rng} --> 781 module_init_outputs = self.module.init(rngs, input_ids, attention_mask) 783 random_params = module_init_outputs["params"] 785 if params is not None:

[... skipping hidden 9 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/phi/modelling_phi_flax.py:706, in FlaxPhiForCausalLMModule.call(self, input_ids, inputs_embeds, attention_mask, position_ids, extra_embedding, deterministic, output_attentions, output_hidden_states, init_cache, return_dict) 693 def call( 694 self, 695 input_ids: Optional[chex.Array] = None, (...) 704 return_dict: bool = True, 705 ) -> tuple[Any, ...] | FlaxMaskedLMOutput: --> 706 res = self.model( 707 input_ids=input_ids, 708 attention_mask=attention_mask, 709 init_cache=init_cache, 710 deterministic=deterministic, 711 extra_embedding=extra_embedding, 712 position_ids=position_ids, 713 output_attentions=output_attentions, 714 output_hidden_states=output_hidden_states, 715 return_dict=True 716 ) 717 outputs = (res.last_hidden_state, res.hidden_states, res.attentions) 718 if self.config.tie_word_embeddings:

[... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/phi/modelling_phi_flax.py:638, in FlaxPhiModule.call(self, input_ids, inputs_embeds, attention_mask, position_ids, extra_embedding, deterministic, output_attentions, output_hidden_states, init_cache, return_dict) 635 assert sequence_length <= self.config.max_position_embeddings, "Maximum Position Embedding Reached !" 636 inputs_embeds = inputs_embeds + extra_embedding if extra_embedding is not None else inputs_embeds --> 638 outputs = self.layers( 639 hidden_states=inputs_embeds, 640 freq_cis=self.freq_cis, 641 attention_mask=attention_mask, 642 position_ids=position_ids, 643 causal_mask=self.causal_mask, 644 deterministic=deterministic, 645 init_cache=init_cache, 646 output_attentions=output_attentions, 647 output_hidden_states=output_hidden_states, 648 return_dict=return_dict, 649 ) 651 hidden_states = outputs[0] 652 hidden_states = self.final_layernorm(hidden_states)

[... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/phi/modelling_phi_flax.py:528, in FlaxPhiDecoderLayerCollection.call(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, deterministic, output_attentions, output_hidden_states, init_cache, return_dict) 518 all_hidden_states += (hidden_states,) 520 # hidden_states: chex.Array, 521 # freq_cis: Tuple[chex.Array, chex.Array], 522 # attention_mask: Optional[chex.Array], (...) 526 # output_attentions: bool = False, 527 # init_cache: bool = False, --> 528 layer_outputs = decoder_layer( 529 hidden_states, 530 freq_cis, 531 attention_mask, 532 position_ids, 533 causal_mask, 534 deterministic, 535 output_attentions, 536 init_cache, 537 ) 539 hidden_states = layer_outputs[0] 541 if output_attentions:

[... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py:553, in core_remat_static..inner(scope_fn, repack_fn, variable_groups, rng_groups, args) 550 y = fn(scope, args) 551 return y, repack_fn(scope) --> 553 return rematted(variable_groups, rng_groups, *dyn_args)

[... skipping hidden 7 frame]

File /usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py:550, in core_remat_static..inner..rematted(variable_groups, rng_groups, dyn_args) 548 args = _repack_remat_args(dyn_args, static_args) 549 scope = scope_fn(variable_groups, rng_groups) --> 550 y = fn(scope, args) 551 return y, repack_fn(scope)

[... skipping hidden 3 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/phi/modelling_phi_flax.py:425, in FlaxPhiDecoderLayer.call(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, deterministic, output_attentions, init_cache) 422 residual = hidden_states 423 hidden_states = self.input_layernorm(hidden_states) --> 425 attn_out = self.self_attn( 426 hidden_states=hidden_states, 427 attention_mask=attention_mask, 428 position_ids=position_ids, 429 output_attentions=output_attentions, 430 deterministic=deterministic, 431 freq_cis=freq_cis, 432 causal_mask=causal_mask, 433 init_cache=init_cache, 434 ) 435 attn_outputs, self_attn_weights = (attn_out[0], attn_out[1]) if len(attn_out) == 2 else (attn_out[0], None) 437 attn_outputs = self.resid_dropout(attn_outputs, deterministic=deterministic)

[... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/phi/modelling_phi_flax.py:361, in FlaxPhiAttention.call(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, deterministic, output_attentions, init_cache) 352 attention_bias = lax.select( 353 attention_mask > 0, 354 jnp.full(attention_mask.shape, 0.0).astype(self.dtype), 355 jnp.full(attention_mask.shape, jnp.finfo( 356 self.dtype).min).astype(self.dtype), 357 ) 359 query_length, key_length = query_states.shape[1], key_states.shape[1] --> 361 attentions = self.attention_performer.call( 362 query_states=query_states, 363 key_states=key_states, 364 value_states=value_states, 365 bias=attention_bias, 366 causal=True, 367 use_pjit_attention_force=self.config.use_pjit_attention_force, 368 dropout_rng=dropout_rng, 369 deterministic=deterministic, 370 query_sequence_length=query_length, 371 key_value_sequence_length=key_length, 372 uses_cache=self.has_variable("cache", "cached_key") or init_cache, 373 ) 374 attentions.attention_outputs = attentions.attention_outputs 375 attn_output = self._merge_heads(attentions.attention_outputs)

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/easy_attention.py:219, in EasyAttention.call(self, query_states, key_states, value_states, query_sequence_length, key_value_sequence_length, bias, segment_ids, causal, deterministic, dropout_rng, use_pjit_attention_force, uses_cache) 206 attentions = self._qkv_normal_op( 207 query_states=query_states, 208 key_states=key_states, (...) 216 key_value_sequence_length=key_value_sequence_length, 217 ) 218 elif self.attn_mechanism == "ring": --> 219 attentions = self._qkv_ring_op( 220 query_states=query_states, 221 key_states=key_states, 222 value_states=value_states, 223 bias=bias, 224 dropout_rng=dropout_rng, 225 use_pjit_attention_force=use_pjit_attention_force, 226 causal=causal, 227 deterministic=deterministic, 228 query_sequence_length=query_sequence_length, 229 key_value_sequence_length=key_value_sequence_length, 230 segment_ids=segment_ids, 231 ) 233 elif self.attn_mechanism == "splash": 234 raise NotImplementedError("Splash Attention is not Implemented YET!")

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/easy_attention.py:320, in EasyAttention._qkv_ring_op(self, query_states, key_states, value_states, query_sequence_length, key_value_sequence_length, bias, causal, deterministic, dropout_rng, use_pjit_attention_force, segment_ids) 307 query_sequence_partition = None if query_states.shape[1] == 1 else "sp" 308 ring_attention_sharded = shard_map( 309 partial(ring_attention_standard, axis_name="sp"), 310 mesh=self.mesh, (...) 318 check_rep=False 319 ) --> 320 attn_output = ring_attention_sharded( 321 query_states, key_states, value_states, bias 322 ) 323 return AttentionOutput( 324 attention_weights=None, 325 attention_outputs=attn_output 326 )

[... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/jax/experimental/shard_map.py:199, in _check_specs_vs_args(f, mesh, in_tree, in_specs, in_specs_flat, xs) 197 if any(f is not no_fail for f in fail): 198 msg = _spec_divisibility_error(f, mesh, in_tree, in_specs, fail) --> 199 raise ValueError(msg)

ValueError: shard_map applied to the function 'functools.partial(<jax._src.custom_derivatives.custom_vjp object at 0x7a3e6c45b9d0>, axis_name='sp')' was given argument arrays with axis sizes that are not evenly divisible by the corresponding mesh axis sizes:

The mesh given has shape (1, 8, 1, 1) with corresponding axis names ('dp', 'fsdp', 'tp', 'sp').

Array arguments' axis sizes must be evenly divisible by the mesh axis or axes indicated by the corresponding elements of the argument's in_specs entry. Consider checking that in_specs are correct, and if so consider changing the mesh axis sizes or else padding the input and adapting 'functools.partial(<jax._src.custom_derivatives.custom_vjp object at 0x7a3e6c45b9d0>, axis_name='sp')' appropriately.

erfanzar commented 3 months ago

Yes, I'll put together a tutorial or an example of using JAX partition specs. Could you please tell me which model you're trying to use? It would help to know that.

erfanzar commented 3 months ago

You have to use `batch_size = len(jax.devices()) * batch_size`.
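
Roughly, that means the following (a sketch; `per_device_batch_size` is a name introduced here just for illustration):

```python
import jax

# Sketch: make the global batch a multiple of the device count so batch-like
# axes divide the mesh evenly (8 devices on a TPU v3-8), then pass the result
# as total_batch_size in TrainArguments.
per_device_batch_size = 1  # hypothetical per-device choice
total_batch_size = len(jax.devices()) * per_device_batch_size
print(total_batch_size)    # 8 on a TPU v3-8
```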

or use a custom partition_spec for the attention production and q, k, v, b; you can set them in the model config, as you can see here where I used that for the DPOTrainer ... https://github.com/erfanzar/EasyDeL?tab=readme-ov-file#dpo-fine-tuning

Anyway, if there's any other issue you can reach out to me or send me the code to debug ❤️.

IvoryTower800 commented 3 months ago

Hello, I managed to get the flash-attention training code working by including the init_input_shape=(8, 1024) parameter in TrainArguments. Thank you for your support and patience. Below is my full working code on Kaggle.

```python
!pip install datasets
!pip install git+https://github.com/erfanzar/EasyDeL.git
!pip install jax[tpu]==0.4.24 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install tensorflow -U
```

```python
import os

# disable Weights and Biases
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    AutoTokenizer,
    GenerationConfig
)
from tqdm import tqdm
import time
import pandas as pd
import numpy as np
from functools import partial
from transformers import set_seed
from ast import literal_eval

from huggingface_hub import interpreter_login

interpreter_login()

IS_TPU = True

import jax.numpy
from EasyDel import (
    TrainArguments,
    CausalLanguageModelTrainer,
    AutoEasyDelModelForCausalLM,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer
from jax.sharding import PartitionSpec

# disable Weights and Biases
os.environ['WANDB_DISABLED'] = "true"
huggingface_repo_id_or_path = "microsoft/phi-2"

model, params = AutoEasyDelModelForCausalLM.from_pretrained(
    huggingface_repo_id_or_path,
)

model.config.add_basic_configurations(
    attn_mechanism="flash",  # Using Flash Attention here you can simply just set this to normal or ring
)

max_length = 2048
tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
configs_to_initialize_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (8, 1024)
}

print(model.config)
```

```python
train_arguments = TrainArguments(
    model_class=type(model),
    model_name="my_first_model_to_train_using_easydel",
    num_train_epochs=3,
    configs_to_initialize_model_class=configs_to_initialize_model_class,
    custom_rule=model.config.get_partition_rules(True),
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    max_sequence_length=max_length,
    optimizer=EasyDelOptimizers.ADAMW,  # "adamw", "lion", "adafactor" are supported
    scheduler=EasyDelSchedulers.LINEAR,
    # "linear", "cosine", "none", "warm_up_cosine" and "warm_up_linear" are supported
    weight_decay=0.01,
    total_batch_size=1,
    max_training_steps=None,  # None to let the trainer decide
    do_train=True,
    do_eval=False,  # it's optional but supported
    backend="tpu",  # the default backend is cpu, so you must specify tpu, cpu or gpu
    max_length=max_length,  # note that you have to change this in the model config too
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),  # how to shard the model across CPUs, GPUs or TPUs
    # training will be sequence- and model-parallel automatically, sharing data between devices
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    init_input_shape=(8, 1024),
    gradient_accumulation_steps=8,
    loss_re_mat="",
    dtype=jnp.bfloat16,
    use_wandb=False  # this disables WANDB usage
)
```

```python
def ultra_chat_prompting_process(data_chunk):
    user_part = [
        chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "user"
    ]
    assistant_part = [
        chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "assistant"
    ]

    prompt = ""

    for uc, ac in zip(user_part, assistant_part):
        prompt += f"<|user|>\n{uc}</s>\n<|assistant|>\n{ac}</s>\n"

    return {"prompt": prompt}


tokenization_process = lambda data_chunk: tokenizer(
    data_chunk["prompt"],
    add_special_tokens=False,
    max_length=max_length,
    padding="max_length"
)

dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
dataset_train = dataset["train_gen"].map(ultra_chat_prompting_process, num_proc=os.cpu_count())
dataset_train = dataset_train.map(
    tokenization_process,
    num_proc=os.cpu_count(),
    remove_columns=dataset_train.column_names
)

trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset_train,
    checkpoint_path=None
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey ! , here's where your model saved {output.checkpoint_path}")
```

IvoryTower800 commented 3 months ago

However, when I changed the attn_mechanism from flash to ring, another error was raised: XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Unsupported input data type in matrix multiplication. Does this mean ring attention is not supported on TPU v3?

Information : track_memory is set to False by default inorder make make training faster. you can turn it on with just passing `track_memory=True` in TrainArguments /usr/local/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by promote_options='default'. table = cls._concat_blocks(blocks, axis=0) Warning : In case of using `finetune = True` and Passing `checkpoint_path = None` you should pass parameters in train function Time Took to Complete Task configure dataloaders (microseconds) : 0.25081634521484375 Time Took to Complete Task configure Model, Optimizer, Scheduler and Config (microseconds) : 5573.355436325073 Time Took to Complete Task configure functions and sharding them (microseconds) : 6001.889228820801 Action : Sharding Passed Parameters Model Contain 2.77968384 Billion Parameters 0%| | 0/96012 [00:00<?, ?it/s]

XlaRuntimeError Traceback (most recent call last) Cell In[1], line 159 151 # you can do the same for evaluation process dataset 153 trainer = CausalLanguageModelTrainer( 154 train_arguments, 155 dataset_train, 156 checkpoint_path=None 157 ) --> 159 output = trainer.train(flax.core.FrozenDict({"params": params})) 160 print(f"Hey ! , here's where your model saved {output.checkpoint_path}")

File /usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py:491, in CausalLanguageModelTrainer.train(self, model_parameters, state) 489 _ = batch.pop(ssb, None) 490 time_s = time.time() --> 491 sharded_state, loss, accuracy = self.sharded_train_step_function( 492 sharded_state, 493 batch 494 ) 495 ttl_time = time.time() - time_s 496 loss_sum = loss.tolist() if loss_sum is None else loss_sum + loss

[... skipping hidden 14 frame]

File /usr/local/lib/python3.10/site-packages/jax/_src/compiler.py:236, in backend_compile(backend, module, options, host_callbacks) 231 return backend.compile(built_c, compile_options=options, 232 host_callbacks=host_callbacks) 233 # Some backends don't have host_callbacks option yet 234 # TODO(sharadmv): remove this fallback when all backends allow compile 235 # to take in host_callbacks --> 236 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Unsupported input data type in matrix multiplication.

The MLIR operation involved: %1145 = "tpu.matmul"(%1143, %1144, %549) {transpose_lhs = false, transpose_rhs = false} : (vector<128x128xbf16>, vector<128x128xbf16>, vector<128x128xf32>) -> vector<128x128xf32> ... additional diagnostics were skipped.

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
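
For reference, the rejected operation is roughly a bf16 × bf16 matmul that accumulates into float32, as in this minimal JAX sketch:

```python
import jax.numpy as jnp
from jax import lax

# Roughly the pattern in the failing Mosaic kernel: bfloat16 operands with a
# float32 accumulator (vector<128x128xbf16> x vector<128x128xbf16> -> vector<128x128xf32>).
a = jnp.ones((128, 128), dtype=jnp.bfloat16)
b = jnp.ones((128, 128), dtype=jnp.bfloat16)
c = lax.dot(a, b, preferred_element_type=jnp.float32)
print(c.dtype)  # float32
```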

erfanzar commented 3 months ago

You are using the code in the wrong way and ... here, I've created a Kaggle example for you: https://www.kaggle.com/citifer/easydel-causal-language-model-trainer-example

IvoryTower800 commented 3 months ago

Thank you! That's really helpful. By the way, is 4-bit or 8-bit training supported by EasyDeL on TPU v3 now? I saw this comment in the documentation: "Right now im looking to make EasyBITs in EasyDel work on TPU-v3 cause on low amp GPUs and old TPUs it might now work as good as it does on TPU-v4/5"

I tried adding this line of code: model.config.bits = 4, but fjformer raises an error saying that the default precision was requested together with quantization. So how should I modify my code to enable 4-bit or 8-bit training?

I'm sorry I asked you so many questions... Thank you very much!

erfanzar commented 3 months ago

I'm glad you found the example helpful. Yes, I'll create another example for that too, but I haven't tried to fix the custom EasyBITs speed issue yet. And feel free to ask any questions.

erfanzar commented 3 months ago

Is this issue closed?