erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options with JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

Step time increasing as training progresses #90

Closed yhavinga closed 5 months ago

yhavinga commented 5 months ago

In one of the longer training runs, now running on a TPU v3-8, I noticed the training ETA kept getting later and later. Also, in the step-time wandb chart (picture below), the higher the step number, the longer the step time.

(screenshot: step-time chart from wandb)

Any ideas what could be the cause? I looked a bit at DataLoader prefetch_factor, but it's only available when using multiprocessing (num_workers > 0).
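For reference, a minimal PyTorch sketch of that constraint (the TensorDataset is just a stand-in; EasyDeL's own data pipeline may not use a torch DataLoader at all):

import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == "__main__":
    # stand-in dataset; prefetch_factor is only accepted/effective with worker processes
    dataset = TensorDataset(torch.arange(1024).unsqueeze(1))

    loader = DataLoader(
        dataset,
        batch_size=2,
        num_workers=2,      # > 0, so batches are loaded in worker processes
        prefetch_factor=4,  # each worker keeps 4 batches prepared ahead of time
    )

    for (batch,) in loader:
        pass  # training step would go here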

PS: Thanks for creating EasyDeL - it's amazing what you've built!

erfanzar commented 5 months ago

Hello, thank you :).

About the issue: could I have access to your training arguments (batch size, etc.)? I would also like a screenshot of your buffer-size chart in WandB.

yhavinga commented 5 months ago

For this specific run the config was like this (looking at it, maybe it is due to the shuffle() operation?). Unfortunately, TRC access ended today; I will retry without shuffle after renewing.

# imports added for completeness (typical EasyDeL usage at the time); model_id and tokenizer_id are defined earlier in the script
import datasets
import jax.numpy as jnp
import EasyDel
from EasyDel import (
    AutoEasyDelModelForCausalLM,
    TrainArguments,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers,
)
from transformers import AutoTokenizer

dataset_train = datasets.load_dataset("yhavinga/nedd_x_chat_instruct_tokenized_zephyr_7b_alpha_padright__b1_", split="train").shuffle()
context_length = len(dataset_train[0]['input_ids'])  # 1024
print(f"Using context length of {context_length}")

tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True)
model, params = AutoEasyDelModelForCausalLM.from_pretrained(model_id)
config = model.config
config.freq_max_position_embeddings = config.max_position_embeddings  # 32768
config.max_position_embeddings = context_length
config.c_max_position_embeddings = config.max_position_embeddings

max_length = config.max_position_embeddings

train_args = TrainArguments(
    model_class=EasyDel.FlaxMistralForCausalLM,
    configs_to_init_model_class={
        'config': config,
        'dtype': jnp.bfloat16,
        'param_dtype': jnp.bfloat16,
        'input_shape': (1, 1)
    },
    custom_rule=config.get_partition_rules(True),
    model_name='TowerDutchTest',
    num_train_epochs=1,
    learning_rate=1e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDelOptimizers.ADAMW,
    scheduler=EasyDelSchedulers.WARM_UP_LINEAR,
    warmup_steps=500,
    weight_decay=0.1,
    total_batch_size=2,
    max_steps=48000,
    save_steps=8000,
    do_train=True,
    do_eval=False,
    backend='tpu',
    max_length=max_length,
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),
    use_pjit_attention_force=False,
    use_flash_attention=False,
    gradient_accumulation_steps=8,
    remove_ckpt_after_load=True,
    ids_to_pop_from_dataset=['token_type_ids'],
    loss_remat="",
    dtype=jnp.bfloat16
)

(screenshot) Wandb link: https://wandb.ai/yepster/EasyDeL-TowerDutchTest/runs/6cweckwk/workspace

yhavinga commented 5 months ago

To check whether shuffling a dataset could cause these increasing times, I plotted the per-sample reading time of a shuffled HF dataset. This is reading with an iterator, but indexed lookup looked almost the same.

(plot: per-sample reading time of a shuffled HF dataset)
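For anyone wanting to reproduce this kind of measurement, a rough sketch (the dataset name and plotting are placeholders, not the exact script used):

import time
import datasets
import matplotlib.pyplot as plt

ds = datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split="train").shuffle(seed=0)

read_times = []
prev = time.perf_counter()
for _ in ds:                      # iterator-based reading; indexed lookup (ds[i]) behaved similarly
    now = time.perf_counter()
    read_times.append(now - prev)
    prev = now

plt.plot(read_times)
plt.xlabel("sample index")
plt.ylabel("read time (s)")
plt.show()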

erfanzar commented 5 months ago

I don't see any issue with your configuration. I guess it might be related to some background process that you might be running, or to your kernel. I have pre-trained more than 10 models and this is the first time I'm seeing the buffer size increase over time.

(screenshot)

Do you have any suggestions for me to fix your issue? I guess you can try disabling shuffle in TrainingArguments and see if the buffer size is still increasing.
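A possible middle ground, sketched here for illustration rather than taken from the thread: keep the shuffled order but materialise it once with datasets' flatten_indices(), so later reads don't go through the shuffle's indices mapping on every access.

import datasets

dataset_train = datasets.load_dataset(
    "yhavinga/nedd_x_chat_instruct_tokenized_zephyr_7b_alpha_padright__b1_",
    split="train",
)
# shuffle() only builds an indices mapping; flatten_indices() rewrites the data
# in the shuffled order so subsequent reads are contiguous again
dataset_train = dataset_train.shuffle(seed=42).flatten_indices()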

erfanzar commented 5 months ago

This issue is being closed because no response has been given.

yhavinga commented 5 months ago

Got renewed TRC access and looked into it a bit more. I tried a couple of things:

  1. Replace jnp with np in the data collator: a synthetic test showed the per-loop time is much faster than with jnp, but unfortunately this didn't seem to solve the increasing step time.
  2. Replace jnp with np in the train loop's perplexity calculation.
  3. Comment out the mean loss and mean accuracy stats; the arrays they operate on grow with the step count (see the sketch after this list).
  4. Set track_mem to off.
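Sketch for change 3, using hypothetical names rather than EasyDeL's actual internals: recomputing a mean over a list that grows every step costs O(steps) work per step, whereas a running mean keeps the per-step cost constant.

# growing-array version: per-step cost grows with the number of logged steps
losses = []

def log_mean_loss_growing(loss: float) -> float:
    losses.append(loss)
    return sum(losses) / len(losses)

# running-mean version: O(1) per step, no growing array
mean_state = {"mean": 0.0, "n": 0}

def log_mean_loss_running(loss: float) -> float:
    mean_state["n"] += 1
    mean_state["mean"] += (loss - mean_state["mean"]) / mean_state["n"]
    return mean_state["mean"]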

Results below:
  - green: unchanged easydel 0.0.42 tag
  - orange: only change 1, on easydel from a few days ago
  - red: all four changes above, on easydel from a few days ago

(plot: step time for the three runs)

erfanzar commented 5 months ago

Hello, and thank you for letting me know where the issue is and how I can make it better.

I have tried using numpy instead of jax.numpy, but that causes issues and a lot of errors in multi-host training.

If you want, you can create a pull request from HEAD and make the changes; otherwise, you can tell me which parts exactly you want modified in order to make it more efficient.
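For illustration (a minimal sketch, not EasyDeL's actual train step): metrics averaged inside a pmapped step with jax.lax.pmean stay on-device and participate in the cross-device (and, in multi-host setups, cross-host) reduction, which is the part plain host-side numpy cannot do.

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.pmap, axis_name="batch")
def mean_loss_across_devices(per_device_loss):
    # average the per-device loss across all devices on the named axis
    return jax.lax.pmean(per_device_loss, axis_name="batch")

per_device_loss = jnp.arange(float(jax.local_device_count()))
print(mean_loss_across_devices(per_device_loss))  # same mean replicated on every local device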

yhavinga commented 5 months ago

I will do some more experimentation; I suspect that only changing the mean calculations is sufficient. When finished, I'll create a PR.