Closed: IvoryTower800 closed this issue 10 months ago.
Try this one:
train_args = TrainArguments(
    model_class=EasyDel.modules.FlaxLlamaForCausalLM,
    configs_to_init_model_class={
        'config': config,
        'dtype': get_dtype(FLAGS.dtype),
        'param_dtype': get_dtype(FLAGS.dtype),
        'input_shape': (NUM_TPU_FSDP_MESH_ORDER, BLOCK_Q_TPU_FLASH_ATTENTION)
    },
    custom_rule=config.get_partition_rules(True),
    model_name=FLAGS.project_name,
    num_train_epochs=FLAGS.num_train_epochs,
    learning_rate=FLAGS.learning_rate,
    learning_rate_end=FLAGS.learning_rate_end,
    optimizer=FLAGS.optimizer,
    scheduler=FLAGS.scheduler,
    weight_decay=0.01,
    total_batch_size=1,
    gradient_accumulation_steps=32,
    max_steps=FLAGS.max_steps,
    do_train=FLAGS.do_train,
    do_eval=FLAGS.do_eval,
    do_test=FLAGS.do_test,
    backend=FLAGS.backend,
    max_length=FLAGS.max_sequence_length,
    gradient_checkpointing='nothing_saveable',
    sharding_array=(1, -1, 1, 1),
    use_pjit_attention_force=False,
    remove_ckpt_after_load=FLAGS.remove_ckpt_after_load,
)
trainer = CausalLanguageModelTrainer(
    train_args,
    dataset_train=dataset_train,
    dataset_eval=dataset_train['eval'] if FLAGS.do_eval else None,
    checkpoint_path=FLAGS.checkpoint_path
)
output = trainer.train(
    model_parameters=flax.core.FrozenDict({'params': params})
)
Set BLOCK_Q and BLOCK_K in the LLaMA config; they will be used in the code like this:
...
block_q=self.config.flash_attn_query_chunk_size,
block_k=self.config.flash_attn_key_chunk_size,
...
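For example, a minimal sketch of wiring this up (assuming `config` is the LLaMA config object passed via configs_to_init_model_class above and a 4096-token sequence length; the divisibility checks are my own sanity check, not something the library enforces here):

max_sequence_length = 4096

config.use_flash_attention = True
config.flash_attn_query_chunk_size = 1024  # read as block_q by the attention code above
config.flash_attn_key_chunk_size = 1024    # read as block_k by the attention code above

# assumption: the chunk sizes should not exceed (and ideally should divide)
# the training sequence length
assert max_sequence_length % config.flash_attn_query_chunk_size == 0
assert max_sequence_length % config.flash_attn_key_chunk_size == 0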
@erfanzar Thank you. I tried your code with 'input_shape': (1, 4096), but I got the new error below: TypeError: fori_loop() got an unexpected keyword argument 'unroll'. It seems like this error comes from fjformer?
def main(argv):
    dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
    dataset_train = dataset['test_sft'].map(formatting_func, num_proc=12)
    dataset_train = dataset_train.remove_columns(['prompt', 'prompt_id', 'messages'])
    params, config = llama_from_pretrained(FLAGS.pretrained_model_name_or_path, jax.devices("cpu")[0])
    config.use_flash_attention = True
    train_args = TrainArguments(
        model_class=EasyDel.modules.FlaxLlamaForCausalLM,
        configs_to_init_model_class={
            'config': config,
            'dtype': get_dtype(FLAGS.dtype),
            'param_dtype': get_dtype(FLAGS.dtype),
            'input_shape': (1, 4096)
        },
        custom_rule=config.get_partition_rules(True),
        model_name=FLAGS.project_name,
        num_train_epochs=FLAGS.num_train_epochs,
        learning_rate=FLAGS.learning_rate,
        learning_rate_end=FLAGS.learning_rate_end,
        optimizer=FLAGS.optimizer,
        scheduler=FLAGS.scheduler,
        weight_decay=0.01,
        total_batch_size=1,
        gradient_accumulation_steps=1,
        max_steps=FLAGS.max_steps,
        do_train=FLAGS.do_train,
        do_eval=FLAGS.do_eval,
        do_test=FLAGS.do_test,
        backend=FLAGS.backend,
        max_length=FLAGS.max_sequence_length,
        gradient_checkpointing='nothing_saveable',
        sharding_array=(1, -1, 1, 1),
        use_pjit_attention_force=False,
        remove_ckpt_after_load=FLAGS.remove_ckpt_after_load,
    )
    trainer = CausalLanguageModelTrainer(
        train_args,
        dataset_train=dataset_train,
        dataset_eval=dataset_train['eval'] if FLAGS.do_eval else None,
        checkpoint_path=FLAGS.checkpoint_path
    )
    output = trainer.train(
        model_parameters=flax.core.FrozenDict({'params': params})
    )
    # Done You can simply train any llama LLM that you want in less than 50 lines of code


if __name__ == "__main__":
    app.run(main)
/root
/usr/local/lib/python3.10/site-packages/jax/_src/cloud_tpu_init.py:73: UserWarning: JAX_USE_PJRT_C_API_ON_TPU no longer has an effect (the new TPU runtime is always enabled now). Unset the environment variable to disable this warning.
warnings.warn(
/usr/local/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by mode='default'.
table = cls._concat_blocks(blocks, axis=0)
Loading checkpoint shards: 100%|██████████████████| 2/2 [00:14<00:00, 7.11s/it]
Warning : In case of using finetune = True and Passing checkpoint_path = None you should pass parametersin train function
wandb: Tracking run with wandb version 0.16.2
wandb: W&B syncing is set to `offline` in this directory.
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
Time For configure dataloaders (ms) : 0.2560615539550781
I0119 10:41:33.663878 135426770996096 mesh_utils.py:71] Reordering mesh to physical ring order on single-tray TPU v2/v3.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/root/train.py", line 224, in
This issue will be fixed soon.
Fixed now; you have to set
config.attn_mechanism = "flash"
config.block_k = 128 # Key State Chunk
config.block_q = 128 # Query State Chunk
config.block_b = 1
config.block_k_major = 128 # Key State Chunk
config.use_flash_attention = True
and everything should be good to go.
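For reference, a minimal sketch of applying these settings to the config from the script earlier in this thread (only the attribute names and values come from the comment above; the 4096 sequence length and the divisibility check are assumptions on my side):

params, config = llama_from_pretrained(FLAGS.pretrained_model_name_or_path, jax.devices("cpu")[0])

config.use_flash_attention = True
config.attn_mechanism = "flash"
config.block_b = 1
config.block_q = 128        # query state chunk
config.block_k = 128        # key state chunk
config.block_k_major = 128  # key state chunk (major)

# assumed: block sizes should divide the training sequence length, and
# 'input_shape' in configs_to_init_model_class should match it, e.g. (1, 4096)
assert 4096 % config.block_q == 0 and 4096 % config.block_k == 0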
@erfanzar Thank you for your patient guidance, but I'm still not able to run the code without errors. Here are the platform, environment, my latest full code (the example code on the homepage of this repo), and the output on Kaggle: https://www.kaggle.com/code/natalina1/easydel
Could you please tell me how to modify the code to fix the error? Many thanks!
I'll soon create a tutorial for that after fixing the MoE issues.
@erfanzar Thank you so much!
Hi, this is a simple example; test this one and, in case you get any error, tell me. Make sure your FJFormer version is 0.0.28 or above and that you are on TPU.
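A quick way to confirm that environment before running the example (a diagnostic sketch; it only reads package metadata and the JAX backend):

import jax
from importlib.metadata import version

print("fjformer:", version("fjformer"))   # should be >= 0.0.28
print("jax:", jax.__version__)
print("backend:", jax.default_backend())  # should print 'tpu' on a TPU VM
print("devices:", jax.devices())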
Hi, I tried to fine-tune a model on a TPU VM v3-8. When not using flash attention, it works. However, when I set config.use_flash_attention = True, an error occurs: block_q=1024 should be smaller or equal to q_seq_len=1.
When I tried to set config.q_seq_len = 4096, it didn't work; it still reports the same error: block_q=1024 should be smaller or equal to q_seq_len=1.
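For context, the check that fails looks roughly like this (a simplified sketch based on the error message and the traceback below, not fjformer's actual implementation). The q_seq_len it compares against is the query length of the tensors the kernel actually receives; during model initialization that apparently comes from the model's input_shape (a (1, 1) dummy input by default here, given q_seq_len=1 in the message), not from config.q_seq_len, which is why setting config.q_seq_len = 4096 does not change the error:

def verify_block(block_name, dim_name, block, dim):
    # mirrors the message "block_q=1024 should be smaller or equal to q_seq_len=1"
    if block > dim:
        raise ValueError(f"{block_name}={block} should be smaller or equal to {dim_name}={dim}")

verify_block("block_q", "q_seq_len", 1024, 1)  # raises, reproducing the error below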
Below is my Code:
def main(argv):
    config.flash_attn_key_chunk_size = 1
    config.flash_attn_query_chunk_size = 1

if __name__ == "__main__":
    app.run(main)
/root
/usr/local/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by mode='default'.
table = cls._concat_blocks(blocks, axis=0)
Loading checkpoint shards: 100%|██████████████████| 2/2 [00:14<00:00, 7.20s/it]
Warning : In case of using finetune = True and Passing checkpoint_path = None you should pass parametersin train function
wandb: Tracking run with wandb version 0.16.2
wandb: W&B syncing is set to `offline` in this directory.
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
Time For configure dataloaders (ms) : 0.2694129943847656
I0119 04:13:23.447986 132476669524864 mesh_utils.py:71] Reordering mesh to physical ring order on single-tray TPU v2/v3.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/root/train.py", line 237, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/root/train.py", line 226, in main
    trainer = CausalLanguageModelTrainer(train_args,
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 243, in __init__
    self.init_functions()
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 300, in init_functions
    self.model, self.tx, self.scheduler, self.config = self.configure_model()
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 397, in configure_model
    model = self.arguments.model_class(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 657, in __init__
    super().__init__(config, module, input_shape=input_shape,
  File "/usr/local/lib/python3.10/site-packages/transformers/modeling_flax_utils.py", line 223, in __init__
    params_shape_tree = jax.eval_shape(init_fn, self.key)
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 692, in init_weights
    module_init_outputs = self.module.init(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1124, in __call__
    outputs = self.model(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1021, in __call__
    outputs = self.layers(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 911, in __call__
    layer_outputs = block(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 575, in __call__
    attn_outputs = self.self_attn(
  File "/usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 553, in inner
    return rematted(variable_groups, rng_groups, *dyn_args)
  File "/usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 550, in rematted
    y = fn(scope, *args)
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 348, in __call__
    attn_output = smart_flash_attention(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/flax_modelling_utils.py", line 455, in smart_flash_attention
    attn_output = fjformer.attention.jax_flash_attn_tpu.flash_attention(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 198, in flash_attention
    return _flash_attention(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 216, in _flash_attention
    return _flash_attention_impl(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 590, in _flash_attention_impl
    _verify_block("block_q", "q_seq_len", block_q, q_seq_len, should_divide=False)
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 1689, in _verify_block
    raise ValueError(
ValueError: block_q=1024 should be smaller or equal to q_seq_len=1
wandb: You can sync this run to the cloud by running:
wandb: wandb sync /root/wandb/offline-run-20240119_041323-e4ftdqry
wandb: Find logs at: ./wandb/offline-run-20240119_041323-e4ftdqry/logs
Then I tried to set config.flash_attn_key_chunk_size = 1 and config.flash_attn_query_chunk_size = 1. Another error occurred: TypeError: fori_loop() got an unexpected keyword argument 'unroll'.
/root
/usr/local/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by mode='default'.
table = cls._concat_blocks(blocks, axis=0)
Loading checkpoint shards: 100%|██████████████████| 2/2 [00:14<00:00, 7.42s/it]
Warning : In case of using finetune = True and Passing checkpoint_path = None you should pass parametersin train function
wandb: Tracking run with wandb version 0.16.2
wandb: W&B syncing is set to `offline` in this directory.
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
Time For configure dataloaders (ms) : 0.25153160095214844
I0119 04:18:41.780579 138762858945408 mesh_utils.py:71] Reordering mesh to physical ring order on single-tray TPU v2/v3.
[the same "Reordering mesh to physical ring order on single-tray TPU v2/v3." INFO line repeats many times and is omitted here]
Time For configure Model ,Optimizer ,Scheduler and Config (ms) : 1676.0611534118652
Time For configure functions and sharding them (ms) : 2244.947671890259
Action : Sharding Passed Parameters
Model Contain 6.929256448 Billion Parameters
0%| | 0/7220 [00:00<?, ?it/s]
I0119 04:20:07.023622 138762858945408 mesh_utils.py:71] Reordering mesh to physical ring order on single-tray TPU v2/v3.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/root/train.py", line 223, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/root/train.py", line 216, in main
    output = trainer.train(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 708, in train
    sharded_state, loss, accuracy = self.sharded_train_step_fn(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 89, in casual_language_model_train_step
    (loss, accuracy), grad = grad_fn(state.params)
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 80, in calculate_loss
    logits = state.apply_fn(params=params, **batch,
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 809, in __call__
    outputs = self.module.apply(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1124, in __call__
    outputs = self.model(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1021, in __call__
    outputs = self.layers(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 911, in __call__
    layer_outputs = block(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 575, in __call__
    attn_outputs = self.self_attn(
  File "/usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 553, in inner
    return rematted(variable_groups, rng_groups, *dyn_args)
  File "/usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 550, in rematted
    y = fn(scope, *args)
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 348, in __call__
    attn_output = smart_flash_attention(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/flax_modelling_utils.py", line 455, in smart_flash_attention
    attn_output = fjformer.attention.jax_flash_attn_tpu.flash_attention(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 198, in flash_attention
    return _flash_attention(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 216, in _flash_attention
    return _flash_attention_impl(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 746, in _flash_attention_impl
    o, *aux = pl.pallas_call(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 379, in wrapped
    grid_mapping, jaxpr, consts, _ = _trace_to_jaxpr(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 338, in _trace_to_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals,
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 337, in _flash_attention_kernel
    kernel((batch_idx, 0), q_tile_ref, *args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 384, in _flash_attention_kernel_single_batch
    def run():
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pallas/utils.py", line 29, in _wrapped
    f()
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 388, in run
    def body(i, _):
TypeError: fori_loop() got an unexpected keyword argument 'unroll'
wandb:
wandb: Run history:
wandb: Number of Model Parameters (Billion) ▁
wandb:
wandb: Run summary:
wandb: Number of Model Parameters (Billion) 6.92926
wandb:
wandb: You can sync this run to the cloud by running:
wandb: wandb sync /root/wandb/offline-run-20240119_041841-decs7x5r
wandb: Find logs at: ./wandb/offline-run-20240119_041841-decs7x5r/logs
Could you please guide me on how to use flash attention on TPU when training a model with EasyDel? Thank you so much!