Removes the zero_first handling in the sample packing function. This function is already called from utils.data.sft.prepare_dataset under a zero_first guard. In fact, if you keep it and process the dataset in a multi-GPU setting, explicitly providing an eval set causes incorrect synchronization:
i. utils.data.sft.prepare_dataset establishes a barrier() on rank 0; all other ranks are stopped.
ii. load_prepare_datasets > load_tokenized_prepared_datasets > process_datasets_for_packing is called for the train set and again establishes a barrier() on rank 0; all other ranks are still stopped.
iii. when process_datasets_for_packing finishes, barrier() is called on all other ranks, which releases the ranks blocked in i. and leads to them trying to load the validation set before it has been processed by rank 0.
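The early release in iii. can be simulated with plain threads. This is a hypothetical sketch, not axolotl code: threading.Barrier stands in for torch.distributed.barrier(), and since all barrier() calls match the same collective, rank 0's inner barrier pairs with rank 1's outer wait.

```python
import threading
import time

# Sketch: threading.Barrier stands in for torch.distributed.barrier().
# All barrier() calls pair with each other, so rank 0's *inner* barrier
# (end of the nested zero_first) releases rank 1's *outer* wait.
barrier = threading.Barrier(2)
events = []

def rank0():
    events.append("rank0: train processed")
    barrier.wait()          # inner zero_first barrier (sample packing)
    time.sleep(0.2)         # simulate eval processing taking a while
    events.append("rank0: eval processed")
    barrier.wait()          # outer zero_first barrier (prepare_dataset)

def rank1():
    barrier.wait()          # outer zero_first guard: wait for rank 0
    events.append("rank1: loads eval")  # resumes too early!
    barrier.wait()          # rank 1's pass through the inner zero_first

t0, t1 = threading.Thread(target=rank0), threading.Thread(target=rank1)
t0.start(); t1.start(); t0.join(); t1.join()
print(events)  # rank 1 loads the eval set before rank 0 has processed it
```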
Adds an option drop_long_sequences which, when set to False, raises an error if sequences longer than sequence_len are found in the dataset. Consequently, this also requires truncation to be turned off in the prompt strategies. Also adds a sequence length histogram printer util.
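A minimal sketch of the intended semantics (function names, the bucket size, and messages here are illustrative, not the actual implementation):

```python
from collections import Counter

def handle_long_sequences(samples, sequence_len, drop_long_sequences=True):
    """Drop samples longer than sequence_len, or error if dropping is disabled."""
    too_long = [s for s in samples if len(s["input_ids"]) > sequence_len]
    if too_long and not drop_long_sequences:
        raise ValueError(
            f"{len(too_long)} sequences exceed sequence_len={sequence_len}; "
            "enable drop_long_sequences to drop them instead"
        )
    return [s for s in samples if len(s["input_ids"]) <= sequence_len]

def print_length_histogram(lengths, bucket_size=512):
    """Print a coarse histogram of tokenized sequence lengths."""
    buckets = Counter((length // bucket_size) * bucket_size for length in lengths)
    for start in sorted(buckets):
        count = buckets[start]
        print(f"{start:>6}-{start + bucket_size - 1:<6} {count:>5} {'#' * count}")

samples = [{"input_ids": [1, 2, 3]}, {"input_ids": list(range(9000))}]
kept = handle_long_sequences(samples, sequence_len=4096)  # long sample dropped
print_length_histogram([len(s["input_ids"]) for s in samples])
```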
Adds dropping of sequences with no learnable outputs. Empty completions or messages, or truncation=True combined with train_on_inputs=False, can produce samples where all labels are -100 with nothing to learn; such samples make the loss NaN.
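The failure mode is easy to see: if every label in a sample is -100 (the ignore index), cross-entropy has zero contributing tokens and the mean loss divides by zero. A hypothetical filter (drop_unlearnable is a made-up name, not the actual code):

```python
IGNORE_INDEX = -100  # standard ignore index for cross-entropy label masking

def has_learnable_tokens(sample):
    """True if at least one label is not the ignore index."""
    return any(label != IGNORE_INDEX for label in sample["labels"])

def drop_unlearnable(samples):
    """Drop samples whose labels are all -100; their mean loss would be 0/0 = NaN."""
    return [s for s in samples if has_learnable_tokens(s)]

samples = [
    {"labels": [-100, -100, 42, 7]},  # has learnable tokens, kept
    {"labels": [-100, -100, -100]},   # fully masked, dropped
]
filtered = drop_unlearnable(samples)
```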
Adds a check that errors out if the dataset ends up empty after all this processing.
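A sketch of what that guard could look like (the function name and message are illustrative):

```python
def raise_if_empty(dataset, stage):
    """Fail loudly instead of starting a run on zero samples (illustrative guard)."""
    if len(dataset) == 0:
        raise ValueError(
            f"Dataset is empty after {stage}; "
            "check sequence_len, truncation, and label masking settings"
        )
    return dataset

raise_if_empty([{"input_ids": [1]}], stage="dropping long sequences")  # passes through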
Things I am not sure about yet and would love some help:
Turning truncation off in all prompt strategies and its consequences
Changes for DPO
Things Todo:
Writing a proper test
For context, my employer has been running these changes for quite a while now (https://github.com/truefoundry/axolotl/pull/5/files), but we mostly use the chat_template prompt strategy and SFT, and axolotl is quite a large surface area codebase, so we would need some help 😅