Could we specify the number of skipped samples (`skip_num_samples`) directly? Or even skipped tokens (`skip_num_tokens`)? This would be particularly convenient if the number of consumed samples / tokens was written to the checkpoint. The number of skipped batches (`skip_num_batches`) could then be derived internally from the new (global) batch size, i.e., the user would not need to think about this.
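A minimal sketch of that internal derivation, assuming the consumed counts are stored in the checkpoint (the helper names and the use of floor division are illustrative assumptions, not taken from the codebase):

```python
def derive_skip_num_batches(skip_num_samples: int, global_batch_size: int) -> int:
    # Derive the batches to skip from the consumed samples and the new (global) batch size.
    return skip_num_samples // global_batch_size


def samples_from_tokens(skip_num_tokens: int, context_size: int) -> int:
    # If only tokens are given, convert to samples first (assumes fixed-length sequences).
    return skip_num_tokens // context_size
```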
`tokens_per_train_step` refers to the global batch size. We could actually call it `global_batch_size` as well :)
If we pass `global_batch_size` and `context_size` to `CheckpointExecution`, then we can calculate the number of seen samples and number of seen tokens there and save them as part of the checkpoint file name, e.g., `eid_{experiment_id}-{entity}-num_steps_{num_train_steps}-num_samples_{num_samples}-num_tokens_{num_tokens}.bin`. For a warmstart, we would pass the number of seen samples or number of seen tokens to the dataloader factory.
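A hedged sketch of what that calculation inside `CheckpointExecution` could look like (the helper name and signature are hypothetical):

```python
def checkpoint_file_name(
    experiment_id: str,
    entity: str,
    num_train_steps: int,
    global_batch_size: int,  # samples per train step across all ranks
    context_size: int,       # tokens per sample
) -> str:
    # Seen samples / tokens derived from the step count and the global batch geometry.
    num_samples = num_train_steps * global_batch_size
    num_tokens = num_samples * context_size
    return (
        f"eid_{experiment_id}-{entity}-num_steps_{num_train_steps}"
        f"-num_samples_{num_samples}-num_tokens_{num_tokens}.bin"
    )
```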
We would then have `local_num_tokens_per_train_step`, which is the number of seen tokens within one step on a single rank, and `global_num_tokens_per_train_step`, which is `local_num_tokens_per_train_step * num_ranks`. So "global" and "local" do not refer to the step but to the num_tokens.
Based on your proposal, I would suggest the following changes:

1. Rename all global step variables, i.e., `global_training_log_interval_in_steps`, `global_checkpointing_interval_in_steps`, `global_evaluation_interval_in_steps`, `global_num_seen_steps` -> `training_log_interval_in_steps`, `checkpointing_interval_in_steps`, `evaluation_interval_in_steps`, `num_seen_steps`, since FSDP and DDP don't have the concept of local vs. global steps. We should still distinguish between local and global batch sizes though!
2. Dataloader: To define the number of batches to be skipped by the dataloader during warmstart, I would suggest we use the variable `skip_num_batches` instead of `skip_num_micro_steps`. The dataloader does not have to know about things like (micro) steps. The calculation should happen outside, possibly even manually in the beginning; see the sketch below the list. When changing the batch_size and num_ranks between the previous run and the warmstart, we might see a few samples twice in this case.
3. Rename `tokens_per_train_step` to `global_num_tokens_per_train_step`.
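As referenced in point 2, a minimal sketch of the `skip_num_batches` calculation, assuming the checkpoint records the number of seen samples and that each rank's dataloader yields micro batches (all names besides `skip_num_batches` are illustrative):

```python
def compute_skip_num_batches(
    num_seen_samples: int,        # total samples consumed in the previous run
    local_micro_batch_size: int,  # samples per dataloader batch on a single rank
    num_ranks: int,               # world size of the warmstart run
) -> int:
    # One dataloader batch index accounts for local_micro_batch_size * num_ranks samples globally.
    samples_per_batch_index = local_micro_batch_size * num_ranks
    # Floor division: if the new batch size / rank count does not evenly divide the
    # previously seen samples, a few samples at the boundary are seen twice after warmstart.
    return num_seen_samples // samples_per_batch_index
```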
What do you think?