erfanzar / EasyDeL

Accelerate, Optimize performance with streamlined training and serving options with JAX.
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

Training on TPU Using Flash Attention #83

Closed IvoryTower800 closed 10 months ago

IvoryTower800 commented 10 months ago

Hi, I tried to finetune a model on a TPU VM v3-8. When not using flash attention, it works. However, when I set config.use_flash_attention = True, an error occurs: block_q=1024 should be smaller or equal to q_seq_len=1.

When I tried to set config.q_seq_len = 4096, it didn't work; it still reported the same error: block_q=1024 should be smaller or equal to q_seq_len=1.

Below is my Code:

def main(argv):
    dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
    dataset_train = dataset['test_sft'].map(formatting_func, num_proc=12)
    dataset_train = dataset_train.remove_columns(['prompt', 'prompt_id', 'messages'])

    params, config = llama_from_pretrained(FLAGS.pretrained_model_name_or_path, jax.devices("cpu")[0])
    config.use_flash_attention = True
    config.q_seq_len = 4096

    config.flash_attn_key_chunk_size = 1

    config.flash_attn_query_chunk_size = 1

    train_args = TrainArguments(
        model_class=EasyDel.modules.FlaxLlamaForCausalLM,
        configs_to_init_model_class={
            'config': config,
            'dtype': get_dtype(FLAGS.dtype),
            'param_dtype': get_dtype(FLAGS.dtype)
        },
        custom_rule=config.get_partition_rules(True),
        model_name=FLAGS.project_name,
        num_train_epochs=FLAGS.num_train_epochs,
        learning_rate=FLAGS.learning_rate,
        learning_rate_end=FLAGS.learning_rate_end,
        optimizer=FLAGS.optimizer,
        scheduler=FLAGS.scheduler,
        weight_decay=0.01,
        total_batch_size=1,
        gradient_accumulation_steps=32,
        max_steps=FLAGS.max_steps,
        do_train=FLAGS.do_train,
        do_eval=FLAGS.do_eval,
        do_test=FLAGS.do_test,
        backend=FLAGS.backend,
        max_length=FLAGS.max_sequence_length,
        gradient_checkpointing='nothing_saveable',
        sharding_array=(1, -1, 1, 1),
        use_pjit_attention_force=False,
        remove_ckpt_after_load=FLAGS.remove_ckpt_after_load,
    )

    trainer = CausalLanguageModelTrainer(
        train_args,
        dataset_train=dataset_train,
        dataset_eval=dataset_train['eval'] if FLAGS.do_eval else None,
        checkpoint_path=FLAGS.checkpoint_path
    )
    output = trainer.train(
        model_parameters=flax.core.FrozenDict({'params': params})
    )
    # Done! You can simply train any Llama LLM that you want in less than 50 lines of code.


if __name__ == "__main__":
    app.run(main)

/root
/usr/local/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by mode='default'.
  table = cls._concat_blocks(blocks, axis=0)
Loading checkpoint shards: 100%|██████████████████| 2/2 [00:14<00:00, 7.20s/it]
Warning : In case of using finetune = True and Passing checkpoint_path = None you should pass parameters in train function
wandb: Tracking run with wandb version 0.16.2
wandb: W&B syncing is set to offline in this directory.
wandb: Run wandb online or set WANDB_MODE=online to enable cloud syncing.
Time For configure dataloaders (ms) : 0.2694129943847656
I0119 04:13:23.447986 132476669524864 mesh_utils.py:71] Reordering mesh to physical ring order on single-tray TPU v2/v3.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/train.py", line 237, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/root/train.py", line 226, in main
    trainer = CausalLanguageModelTrainer(train_args,
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 243, in __init__
    self.init_functions()
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 300, in init_functions
    self.model, self.tx, self.scheduler, self.config = self.configure_model()
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 397, in configure_model
    model = self.arguments.model_class(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 657, in __init__
    super().__init__(config, module, input_shape=input_shape,
  File "/usr/local/lib/python3.10/site-packages/transformers/modeling_flax_utils.py", line 223, in __init__
    params_shape_tree = jax.eval_shape(init_fn, self.key)
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 692, in init_weights
    module_init_outputs = self.module.init(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1124, in __call__
    outputs = self.model(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1021, in __call__
    outputs = self.layers(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 911, in __call__
    layer_outputs = block(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 575, in __call__
    attn_outputs = self.self_attn(
  File "/usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 553, in inner
    return rematted(variable_groups, rng_groups, *dyn_args)
  File "/usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 550, in rematted
    y = fn(scope, *args)
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 348, in __call__
    attn_output = smart_flash_attention(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/flax_modelling_utils.py", line 455, in smart_flash_attention
    attn_output = fjformer.attention.jax_flash_attn_tpu.flash_attention(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 198, in flash_attention
    return _flash_attention(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 216, in _flash_attention
    return _flash_attention_impl(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 590, in _flash_attention_impl
    _verify_block("block_q", "q_seq_len", block_q, q_seq_len, should_divide=False)
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 1689, in _verify_block
    raise ValueError(
ValueError: block_q=1024 should be smaller or equal to q_seq_len=1
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/train.py", line 237, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/root/train.py", line 226, in main
    trainer = CausalLanguageModelTrainer(train_args,
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 243, in __init__
    self.init_functions()
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 300, in init_functions
    self.model, self.tx, self.scheduler, self.config = self.configure_model()
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 397, in configure_model
    model = self.arguments.model_class(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 657, in __init__
    super().__init__(config, module, input_shape=input_shape,
  File "/usr/local/lib/python3.10/site-packages/transformers/modeling_flax_utils.py", line 223, in __init__
    params_shape_tree = jax.eval_shape(init_fn, self.key)
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 692, in init_weights
    module_init_outputs = self.module.init(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1124, in __call__
    outputs = self.model(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1021, in __call__
    outputs = self.layers(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 911, in __call__
    layer_outputs = block(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 575, in __call__
    attn_outputs = self.self_attn(
  File "/usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 553, in inner
    return rematted(variable_groups, rng_groups, *dyn_args)
  File "/usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 550, in rematted
    y = fn(scope, *args)
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 348, in __call__
    attn_output = smart_flash_attention(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/flax_modelling_utils.py", line 455, in smart_flash_attention
    attn_output = fjformer.attention.jax_flash_attn_tpu.flash_attention(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 198, in flash_attention
    return _flash_attention(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 216, in _flash_attention
    return _flash_attention_impl(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 590, in _flash_attention_impl
    _verify_block("block_q", "q_seq_len", block_q, q_seq_len, should_divide=False)
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 1689, in _verify_block
    raise ValueError(
ValueError: block_q=1024 should be smaller or equal to q_seq_len=1
wandb: You can sync this run to the cloud by running:
wandb: wandb sync /root/wandb/offline-run-20240119_041323-e4ftdqry
wandb: Find logs at: ./wandb/offline-run-20240119_041323-e4ftdqry/logs

Then I tried to set config.flash_attn_key_chunk_size = 1 and config.flash_attn_query_chunk_size = 1. Another error occurred: TypeError: fori_loop() got an unexpected keyword argument 'unroll'

/root
/usr/local/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by mode='default'.
  table = cls._concat_blocks(blocks, axis=0)
Loading checkpoint shards: 100%|██████████████████| 2/2 [00:14<00:00, 7.42s/it]
Warning : In case of using finetune = True and Passing checkpoint_path = None you should pass parameters in train function
wandb: Tracking run with wandb version 0.16.2
wandb: W&B syncing is set to offline in this directory.
wandb: Run wandb online or set WANDB_MODE=online to enable cloud syncing.
Time For configure dataloaders (ms) : 0.25153160095214844
I0119 04:18:41.780579 138762858945408 mesh_utils.py:71] Reordering mesh to physical ring order on single-tray TPU v2/v3.
[the same "Reordering mesh to physical ring order on single-tray TPU v2/v3." log line repeats dozens of times between 04:18:41 and 04:18:45 while the model, optimizer, scheduler and sharded functions are configured]
Time For configure Model ,Optimizer ,Scheduler and Config (ms) : 1676.0611534118652
Time For configure functions and sharding them (ms) : 2244.947671890259
Action : Sharding Passed Parameters
Model Contain 6.929256448 Billion Parameters
0%| | 0/7220 [00:00<?, ?it/s]
I0119 04:20:07.023622 138762858945408 mesh_utils.py:71] Reordering mesh to physical ring order on single-tray TPU v2/v3.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/train.py", line 223, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/root/train.py", line 216, in main
    output = trainer.train(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 708, in train
    sharded_state, loss, accuracy = self.sharded_train_step_fn(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 89, in casual_language_model_train_step
    (loss, accuracy), grad = grad_fn(state.params)
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 80, in calculate_loss
    logits = state.apply_fn(params=params, **batch,
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 809, in __call__
    outputs = self.module.apply(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1124, in __call__
    outputs = self.model(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1021, in __call__
    outputs = self.layers(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 911, in __call__
    layer_outputs = block(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 575, in __call__
    attn_outputs = self.self_attn(
  File "/usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 553, in inner
    return rematted(variable_groups, rng_groups, *dyn_args)
  File "/usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 550, in rematted
    y = fn(scope, *args)
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 348, in __call__
    attn_output = smart_flash_attention(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/flax_modelling_utils.py", line 455, in smart_flash_attention
    attn_output = fjformer.attention.jax_flash_attn_tpu.flash_attention(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 198, in flash_attention
    return _flash_attention(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 216, in _flash_attention
    return _flash_attention_impl(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 746, in _flash_attention_impl
    o, *aux = pl.pallas_call(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 379, in wrapped
    grid_mapping, jaxpr, consts, _ = _trace_to_jaxpr(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 338, in _trace_to_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals,
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 337, in _flash_attention_kernel
    kernel((batch_idx, 0), q_tile_ref, *args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 384, in _flash_attention_kernel_single_batch
    def run():
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pallas/utils.py", line 29, in _wrapped
    f()
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 388, in run
    def body(i, _):
TypeError: fori_loop() got an unexpected keyword argument 'unroll'
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/train.py", line 223, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/root/train.py", line 216, in main
    output = trainer.train(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 708, in train
    sharded_state, loss, accuracy = self.sharded_train_step_fn(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 89, in casual_language_model_train_step
    (loss, accuracy), grad = grad_fn(state.params)
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 80, in calculate_loss
    logits = state.apply_fn(params=params, **batch,
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 809, in __call__
    outputs = self.module.apply(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1124, in __call__
    outputs = self.model(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1021, in __call__
    outputs = self.layers(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 911, in __call__
    layer_outputs = block(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 575, in __call__
    attn_outputs = self.self_attn(
  File "/usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 553, in inner
    return rematted(variable_groups, rng_groups, *dyn_args)
  File "/usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 550, in rematted
    y = fn(scope, *args)
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 348, in __call__
    attn_output = smart_flash_attention(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/flax_modelling_utils.py", line 455, in smart_flash_attention
    attn_output = fjformer.attention.jax_flash_attn_tpu.flash_attention(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 198, in flash_attention
    return _flash_attention(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 216, in _flash_attention
    return _flash_attention_impl(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 746, in _flash_attention_impl
    o, *aux = pl.pallas_call(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 379, in wrapped
    grid_mapping, jaxpr, consts, _ = _trace_to_jaxpr(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 338, in _trace_to_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals,
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 337, in _flash_attention_kernel
    kernel((batch_idx, 0), q_tile_ref, *args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 384, in _flash_attention_kernel_single_batch
    def run():
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pallas/utils.py", line 29, in _wrapped
    f()
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 388, in run
    def body(i, _):
TypeError: fori_loop() got an unexpected keyword argument 'unroll'
wandb:
wandb: Run history:
wandb: Number of Model Parameters (Billion) ▁
wandb:
wandb: Run summary:
wandb: Number of Model Parameters (Billion) 6.92926
wandb:
wandb: You can sync this run to the cloud by running:
wandb: wandb sync /root/wandb/offline-run-20240119_041841-decs7x5r
wandb: Find logs at: ./wandb/offline-run-20240119_041841-decs7x5r/logs

Could you please guide me on how to use flash attention on TPU when training a model with EasyDeL? Thank you so much!

erfanzar commented 10 months ago

try this one

train_args = TrainArguments(
    model_class=EasyDel.modules.FlaxLlamaForCausalLM,
    configs_to_init_model_class={
        'config': config,
        'dtype': get_dtype(FLAGS.dtype),
        'param_dtype': get_dtype(FLAGS.dtype),
        'input_shape':(NUM_TPU_FSDP_MESH_ORDER, BLOCK_Q_TPU_FLASH_ATTENTION)
    },
    custom_rule=config.get_partition_rules(True),
    model_name=FLAGS.project_name,
    num_train_epochs=FLAGS.num_train_epochs,
    learning_rate=FLAGS.learning_rate,
    learning_rate_end=FLAGS.learning_rate_end,
    optimizer=FLAGS.optimizer,
    scheduler=FLAGS.scheduler,
    weight_decay=0.01,
    total_batch_size=1,
    gradient_accumulation_steps=32,
    max_steps=FLAGS.max_steps,
    do_train=FLAGS.do_train,
    do_eval=FLAGS.do_eval,
    do_test=FLAGS.do_test,
    backend=FLAGS.backend,
    max_length=FLAGS.max_sequence_length,
    gradient_checkpointing='nothing_saveable',
    sharding_array=(1, -1, 1, 1),
    use_pjit_attention_force=False,

    remove_ckpt_after_load=FLAGS.remove_ckpt_after_load,

)

trainer = CausalLanguageModelTrainer(train_args,
                                     dataset_train=dataset_train,
                                     dataset_eval=dataset_train['eval'] if FLAGS.do_eval else None,
                                     checkpoint_path=FLAGS.checkpoint_path)
output = trainer.train(
    model_parameters=flax.core.FrozenDict({'params': params})
)
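
The key change here is the explicit 'input_shape': its second dimension is the sequence length that jax.eval_shape sees while building the weight shapes, and the TPU flash-attention kernel requires block_q <= q_seq_len, so a too-small init shape is exactly what produces the block_q=1024 vs q_seq_len=1 error above. A hedged, illustrative choice for the two placeholders (adjust to your own mesh and training length):

# Illustrative values only; these names are placeholders, not EasyDeL constants.
NUM_TPU_FSDP_MESH_ORDER = 1            # batch dimension used only for weight-shape inference
BLOCK_Q_TPU_FLASH_ATTENTION = 4096     # init sequence length, at least the kernel's block_q (1024 by default)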
erfanzar commented 10 months ago

Set BLOCK_Q and BLOCK_K in the Llama config; they will be used in the code like this:

...
                block_q=self.config.flash_attn_query_chunk_size,
                block_k=self.config.flash_attn_key_chunk_size,
...
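
For example (illustrative values; the chunks should be no larger than the sequence length, and 128 divides a 4096-token max_length evenly):

config.flash_attn_query_chunk_size = 128   # used as block_q in the TPU kernel
config.flash_attn_key_chunk_size = 128     # used as block_k in the TPU kernel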
IvoryTower800 commented 10 months ago

@erfanzar Thank you. I tried your code with 'input_shape': (1, 4096), but I got the new error below: TypeError: fori_loop() got an unexpected keyword argument 'unroll'. It seems like this error comes from fjformer?
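
A purely diagnostic sketch I can run to confirm that: the kernel passes an unroll argument to jax.lax.fori_loop, and a JAX release that does not accept that keyword would raise exactly this TypeError.

import inspect
import jax

print(jax.__version__)
print("fori_loop accepts unroll:",
      "unroll" in inspect.signature(jax.lax.fori_loop).parameters)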

def main(argv):
    dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
    dataset_train = dataset['test_sft'].map(formatting_func, num_proc=12)
    dataset_train = dataset_train.remove_columns(['prompt', 'prompt_id', 'messages'])

    params, config = llama_from_pretrained(FLAGS.pretrained_model_name_or_path, jax.devices("cpu")[0])
    config.use_flash_attention = True

    train_args = TrainArguments(
        model_class=EasyDel.modules.FlaxLlamaForCausalLM,
        configs_to_init_model_class={
            'config': config,
            'dtype': get_dtype(FLAGS.dtype),
            'param_dtype': get_dtype(FLAGS.dtype),
            'input_shape': (1, 4096)
        },
        custom_rule=config.get_partition_rules(True),
        model_name=FLAGS.project_name,
        num_train_epochs=FLAGS.num_train_epochs,
        learning_rate=FLAGS.learning_rate,
        learning_rate_end=FLAGS.learning_rate_end,
        optimizer=FLAGS.optimizer,
        scheduler=FLAGS.scheduler,
        weight_decay=0.01,
        total_batch_size=1,
        gradient_accumulation_steps=1,
        max_steps=FLAGS.max_steps,
        do_train=FLAGS.do_train,
        do_eval=FLAGS.do_eval,
        do_test=FLAGS.do_test,
        backend=FLAGS.backend,
        max_length=FLAGS.max_sequence_length,
        gradient_checkpointing='nothing_saveable',
        sharding_array=(1, -1, 1, 1),
        use_pjit_attention_force=False,
        remove_ckpt_after_load=FLAGS.remove_ckpt_after_load,
    )

    trainer = CausalLanguageModelTrainer(
        train_args,
        dataset_train=dataset_train,
        dataset_eval=dataset_train['eval'] if FLAGS.do_eval else None,
        checkpoint_path=FLAGS.checkpoint_path
    )
    output = trainer.train(
        model_parameters=flax.core.FrozenDict({'params': params})
    )
    # Done! You can simply train any Llama LLM that you want in less than 50 lines of code.


if __name__ == "__main__":
    app.run(main)

/root
/usr/local/lib/python3.10/site-packages/jax/_src/cloud_tpu_init.py:73: UserWarning: JAX_USE_PJRT_C_API_ON_TPU no longer has an effect (the new TPU runtime is always enabled now). Unset the environment variable to disable this warning.
  warnings.warn(
/usr/local/lib/python3.10/site-packages/datasets/table.py:1421: FutureWarning: promote has been superseded by mode='default'.
  table = cls._concat_blocks(blocks, axis=0)
Loading checkpoint shards: 100%|██████████████████| 2/2 [00:14<00:00, 7.11s/it]
Warning : In case of using finetune = True and Passing checkpoint_path = None you should pass parameters in train function
wandb: Tracking run with wandb version 0.16.2
wandb: W&B syncing is set to offline in this directory.
wandb: Run wandb online or set WANDB_MODE=online to enable cloud syncing.
Time For configure dataloaders (ms) : 0.2560615539550781
I0119 10:41:33.663878 135426770996096 mesh_utils.py:71] Reordering mesh to physical ring order on single-tray TPU v2/v3.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/train.py", line 224, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/root/train.py", line 213, in main
    trainer = CausalLanguageModelTrainer(train_args,
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 243, in __init__
    self.init_functions()
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 300, in init_functions
    self.model, self.tx, self.scheduler, self.config = self.configure_model()
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 397, in configure_model
    model = self.arguments.model_class(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 657, in __init__
    super().__init__(config, module, input_shape=input_shape,
  File "/usr/local/lib/python3.10/site-packages/transformers/modeling_flax_utils.py", line 223, in __init__
    params_shape_tree = jax.eval_shape(init_fn, self.key)
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 692, in init_weights
    module_init_outputs = self.module.init(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1124, in __call__
    outputs = self.model(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1021, in __call__
    outputs = self.layers(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 911, in __call__
    layer_outputs = block(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 575, in __call__
    attn_outputs = self.self_attn(
  File "/usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 553, in inner
    return rematted(variable_groups, rng_groups, *dyn_args)
  File "/usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 550, in rematted
    y = fn(scope, *args)
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 348, in __call__
    attn_output = smart_flash_attention(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/flax_modelling_utils.py", line 455, in smart_flash_attention
    attn_output = fjformer.attention.jax_flash_attn_tpu.flash_attention(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 198, in flash_attention
    return _flash_attention(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 216, in _flash_attention
    return _flash_attention_impl(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 746, in _flash_attention_impl
    o, *aux = pl.pallas_call(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 379, in wrapped
    grid_mapping, jaxpr, consts, _ = _trace_to_jaxpr(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 338, in _trace_to_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals,
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 337, in _flash_attention_kernel
    kernel((batch_idx, 0), q_tile_ref, *args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 384, in _flash_attention_kernel_single_batch
    def run():
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pallas/utils.py", line 29, in _wrapped
    f()
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 388, in run
    def body(i, _):
TypeError: fori_loop() got an unexpected keyword argument 'unroll'
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/train.py", line 224, in <module>
    app.run(main)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/root/train.py", line 213, in main
    trainer = CausalLanguageModelTrainer(train_args,
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 243, in __init__
    self.init_functions()
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 300, in init_functions
    self.model, self.tx, self.scheduler, self.config = self.configure_model()
  File "/usr/local/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 397, in configure_model
    model = self.arguments.model_class(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 657, in __init__
    super().__init__(config, module, input_shape=input_shape,
  File "/usr/local/lib/python3.10/site-packages/transformers/modeling_flax_utils.py", line 223, in __init__
    params_shape_tree = jax.eval_shape(init_fn, self.key)
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 692, in init_weights
    module_init_outputs = self.module.init(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1124, in __call__
    outputs = self.model(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 1021, in __call__
    outputs = self.layers(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 911, in __call__
    layer_outputs = block(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 575, in __call__
    attn_outputs = self.self_attn(
  File "/usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 553, in inner
    return rematted(variable_groups, rng_groups, *dyn_args)
  File "/usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py", line 550, in rematted
    y = fn(scope, *args)
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py", line 348, in __call__
    attn_output = smart_flash_attention(
  File "/usr/local/lib/python3.10/site-packages/EasyDel/modules/flax_modelling_utils.py", line 455, in smart_flash_attention
    attn_output = fjformer.attention.jax_flash_attn_tpu.flash_attention(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 198, in flash_attention
    return _flash_attention(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 216, in _flash_attention
    return _flash_attention_impl(
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 746, in _flash_attention_impl
    o, *aux = pl.pallas_call(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 379, in wrapped
    grid_mapping, jaxpr, consts, _ = _trace_to_jaxpr(
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py", line 338, in _trace_to_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, jaxpr_flat_avals,
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 337, in _flash_attention_kernel
    kernel((batch_idx, 0), q_tile_ref, *args, **kwargs)
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 384, in _flash_attention_kernel_single_batch
    def run():
  File "/usr/local/lib/python3.10/site-packages/jax/_src/pallas/utils.py", line 29, in _wrapped
    f()
  File "/usr/local/lib/python3.10/site-packages/fjformer/attention/jax_flash_attn_tpu.py", line 388, in run
    def body(i, _):
TypeError: fori_loop() got an unexpected keyword argument 'unroll'
wandb: You can sync this run to the cloud by running:
wandb: wandb sync /root/wandb/offline-run-20240119_104133-trahclvu
wandb: Find logs at: ./wandb/offline-run-20240119_104133-trahclvu/logs

erfanzar commented 10 months ago

This issue will be fixed soon.

erfanzar commented 10 months ago

Fixed. Now you have to set

config.attn_mechanism = "flash"
config.block_k = 128 # Key State Chunk
config.block_q = 128 # Query State Chunk
config.block_b = 1
config.block_k_major = 128 # Key State Chunk
config.use_flash_attention = True

and everything should be good to go
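
For example, with a 4096-token max_length the pieces fit together like this (illustrative values only, not the only valid choice):

config.attn_mechanism = "flash"
config.use_flash_attention = True
config.block_b = 1
config.block_q = 128          # 128 * 32 = 4096, so the query blocks tile the sequence evenly
config.block_k = 128
config.block_k_major = 128
# keep the init shape consistent with the training length as well, e.g.
# configs_to_init_model_class={..., 'input_shape': (1, 4096)}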

IvoryTower800 commented 10 months ago

@erfanzar Thank you for your patient guidance, but I'm still not able to run the code without errors. Here are the platform, environment, my latest full code (the example code on the homepage of this repo), and the output on Kaggle: https://www.kaggle.com/code/natalina1/easydel

Could you please tell me how to modify the code to fix the error? Many thanks!

erfanzar commented 10 months ago

I'll soon create a tutorial for that after fixing the MoE issues.

IvoryTower800 commented 10 months ago

@erfanzar Thank you so much!

erfanzar commented 10 months ago

Hi, this is a simple example; test this one and, in case you get any error, tell me. Make sure your FJFormer version is above or equal to 0.0.28 and that you are on a TPU.