Modalities / modalities

A framework for training multimodal foundation models.
MIT License

Fix Training Step Logging & Log Number of Consumed Tokens #137

Closed mali-git closed 1 month ago

le1nux commented 1 month ago

If you look at the config file now, e.g. here in L8-11, I think there is a general problem:

For the parameters global_training_log_interval_in_steps, global_checkpointing_interval_in_steps & global_evaluation_interval_in_steps, "steps" corresponds to "optimizer steps". In contrast, for the parameter global_num_seen_steps (and the related skip_num_micro_steps), "steps" refers to "micro batch steps".

This seems confusing. Maybe we should either have this difference explicitly reflected in the names of the parameters (e.g. global_num_seen_steps -> global_num_seen_micro_steps), or make further changes such that "steps" always refers to the same thing.

Based on your proposal, I would suggest the following changes:

  1. Rename all global step variables, i.e., global_training_log_interval_in_steps, global_checkpointing_interval_in_steps, global_evaluation_interval_in_steps, global_num_seen_steps -> training_log_interval_in_steps, checkpointing_interval_in_steps, evaluation_interval_in_steps, num_seen_steps, since FSDP and DDP don't have the concept of local vs. global steps. We should still distinguish between local and global batch sizes, though!

  2. Dataloader: To define the number of batches to be skipped for the dataloader during warmstart, I would suggest we use the variable skip_num_batches instead of skip_num_micro_steps. The dataloader does not have to know about things like (micro) steps. The calculation should happen outside, possibly even manually in the beginning. The calculation would be

    skip_num_samples = old_batch_size*old_gradient_accumulation_steps*old_num_steps*old_num_ranks
    skip_num_batches = skip_num_samples // (current_batch_size*current_num_ranks)

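    When the batch_size or num_ranks changes between the previous run and the warmstart, skip_num_samples may not be a multiple of current_batch_size*current_num_ranks. Since the integer division rounds down, we might see a few samples twice in this case.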

  3. Rename tokens_per_train_step to global_num_tokens_per_train_step
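
A minimal Python sketch of the calculation in item 2, assuming batch_size is the per-rank micro batch size and that samples are consumed in the same global order in both runs. The function name compute_skip_num_batches and its second return value are hypothetical and not part of modalities:

```python
from typing import Tuple


def compute_skip_num_batches(
    old_batch_size: int,
    old_gradient_accumulation_steps: int,
    old_num_steps: int,
    old_num_ranks: int,
    current_batch_size: int,
    current_num_ranks: int,
) -> Tuple[int, int]:
    """Return (skip_num_batches, num_repeated_samples) for a warmstart.

    Hypothetical helper: assumes batch sizes are per-rank micro batch sizes.
    """
    # Samples consumed by the previous run across all ranks.
    skip_num_samples = (
        old_batch_size
        * old_gradient_accumulation_steps
        * old_num_steps
        * old_num_ranks
    )
    # Samples consumed per dataloader batch across all ranks in the new run.
    samples_per_batch = current_batch_size * current_num_ranks
    # Floor division: we never skip samples that were not seen before.
    skip_num_batches = skip_num_samples // samples_per_batch
    # Remainder: already-seen samples that fall into the first non-skipped batch.
    num_repeated_samples = skip_num_samples - skip_num_batches * samples_per_batch
    return skip_num_batches, num_repeated_samples


# Same settings in both runs: nothing is repeated.
print(compute_skip_num_batches(4, 2, 100, 8, 4, 8))  # (200, 0)
# Changed batch_size and num_ranks: 16 samples would be seen twice.
print(compute_skip_num_batches(4, 2, 100, 8, 6, 4))  # (266, 16)
```

The remainder is always smaller than current_batch_size*current_num_ranks, so the overlap is bounded by one batch across all ranks.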

What do you think?

flxst commented 1 month ago

  1. I think that's a good idea.
  2. This would also be an improvement in my opinion. However, can't we go a step further and specify the number of skipped samples (skip_num_samples) directly? Or even skipped tokens (skip_num_tokens)? This would be particularly convenient if the number of consumed samples / tokens was written to the checkpoint. The number of skipped batches (skip_num_batches) may then be derived using the new (global) batch size internally, i.e. the user doesn't need to think about this.
  3. Wouldn't this contradict the logic behind the suggested changes in 1.? "steps" is short for "optimizer steps", so it is clear that tokens_per_train_step refers to the global batch size. We could actually call it global_batch_size as well :)

le1nux commented 1 month ago

  1. ok I'll change this then
  2. Good idea! If we pass the global_batch_size and context_size to CheckpointExexecution, then we can calculate the number of seen samples and number of seen tokens there and save it as part of the checkpoint file name, e.g., eid_{experiment_id}-{entity}-num_steps_{num_train_steps}-num_samples_{num_samples}-num_tokens_{num_tokens}.bin. For a warmstart, we would pass in the number of seen samples or number of seen tokens to the dataloader factory.
  3. From my point of view, this is not a contradiction. There is local_num_tokens_per_train_step, which is the number of tokens seen within one step on a single rank, and global_num_tokens_per_train_step, which is local_num_tokens_per_train_step*num_ranks. So "global" and "local" do not refer to the step but to num_tokens.
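
To make item 2 concrete, here is a rough Python sketch of how the counters and the file name could be derived. It assumes that global_batch_size is the number of samples per optimizer step across all ranks (including gradient accumulation) and that every sample has exactly context_size tokens. All function names are hypothetical and not taken from the modalities code base; only the file name pattern comes from the comment above:

```python
import re
from typing import Tuple


def num_seen_samples_and_tokens(
    num_train_steps: int, global_batch_size: int, context_size: int
) -> Tuple[int, int]:
    """Samples and tokens consumed after num_train_steps optimizer steps.

    Assumption: each sample contributes exactly context_size tokens.
    """
    num_samples = num_train_steps * global_batch_size
    num_tokens = num_samples * context_size
    return num_samples, num_tokens


def checkpoint_file_name(
    experiment_id: str,
    entity: str,
    num_train_steps: int,
    global_batch_size: int,
    context_size: int,
) -> str:
    """Build a file name following the pattern proposed in item 2."""
    num_samples, num_tokens = num_seen_samples_and_tokens(
        num_train_steps, global_batch_size, context_size
    )
    return (
        f"eid_{experiment_id}-{entity}-num_steps_{num_train_steps}"
        f"-num_samples_{num_samples}-num_tokens_{num_tokens}.bin"
    )


def parse_num_seen_samples(file_name: str) -> int:
    """Recover num_samples from a file name, e.g. to configure a warmstart."""
    match = re.search(r"-num_samples_(\d+)-", file_name)
    if match is None:
        raise ValueError(f"no num_samples field in {file_name!r}")
    return int(match.group(1))


name = checkpoint_file_name("exp0", "model", 100, 64, 2048)
print(name)
# eid_exp0-model-num_steps_100-num_samples_6400-num_tokens_13107200.bin
print(parse_num_seen_samples(name))  # 6400
```

With the sample count stored this way, the value passed to the dataloader factory on warmstart would not depend on the batch size or number of ranks of the previous run.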