yzhangcs opened 1 week ago
cc @muellerzr @lhoestq
Hey, just giving my 2 cents: unless I'm missing something, this seems extremely simple to implement, since StatefulDataLoader is a drop-in replacement for DataLoader (i.e. literally just replace the DataLoader construction with a StatefulDataLoader construction in trainer.py).
If it's simple enough, I could probably take a shot at implementing it?
The only caveat is that torchdata.stateful_dataloader seems to be a beta feature only available in the nightly build. Does it make sense to officially support unreleased features?
@byi8220 Hi, as far as I can see, the HF Trainer uses the accelerate library internally to prepare the dataloader. This process returns custom classes like DataLoaderShard to handle distributed data dispatch. I think it might be challenging to combine the Trainer directly with StatefulDataLoader without delving into the intricate details of the Trainer implementation.
Hm, maybe I'm misunderstanding the problem. My understanding is that the issue is this: when the Trainer loads from a checkpoint, it calls skip_first_batches to skip past the beginning of the dataset until the DataLoader iterator points to where it was at that checkpoint.
For an IterableDataset, this is done under the hood by manually looping over the items being skipped. StatefulDataLoader may solve this problem by writing its state dict to the checkpoint when saving, and by calling load_state_dict somewhere in the Trainer when loading the checkpoint.
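To make the cost difference concrete, here is a stdlib-only toy. ToyStatefulLoader and skip_first_batches_naive are hypothetical names, not the torchdata or accelerate APIs; the toy just counts how many items each resume strategy reads before it produces the first unseen batch. Real stateful loaders seek differently (for example per shard), so treat this as an illustration of the idea only.

```python
class ToyStatefulLoader:
    """Hypothetical stand-in for a stateful loader over an in-memory list."""

    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size
        self.batches_yielded = 0  # position, saved in the state dict
        self.items_read = 0       # cost counter: how many items were actually read

    def __iter__(self):
        start = self.batches_yielded * self.batch_size
        for i in range(start, len(self.data), self.batch_size):
            batch = self.data[i:i + self.batch_size]
            self.items_read += len(batch)
            self.batches_yielded += 1
            yield batch

    def state_dict(self):
        return {"batches_yielded": self.batches_yielded}

    def load_state_dict(self, state):
        self.batches_yielded = state["batches_yielded"]


def skip_first_batches_naive(loader, num_batches):
    """Resume by replaying the stream from the start and discarding num_batches batches."""
    it = iter(loader)
    for _ in range(num_batches):
        next(it)
    return it


# First run: consume 3 batches, then "checkpoint".
data = list(range(100))
first_run = ToyStatefulLoader(data, batch_size=10)
it = iter(first_run)
for _ in range(3):
    next(it)
saved = first_run.state_dict()

# Resume A: naive skipping re-reads the 30 already-seen items.
naive = ToyStatefulLoader(data, batch_size=10)
naive_batch = next(skip_first_batches_naive(naive, saved["batches_yielded"]))

# Resume B: restoring the state jumps straight to the next unseen batch.
restored = ToyStatefulLoader(data, batch_size=10)
restored.load_state_dict(saved)
restored_batch = next(iter(restored))

print(naive_batch == restored_batch, naive.items_read, restored.items_read)  # True 40 10
```

Both strategies land on the same batch; only the amount of data touched on the way differs, and that gap grows with the number of batches already consumed.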
> This process involves returning self-defined classes like DataLoaderShard to handle cases involving distributed data dispatch.
Yes, it seems like DataLoaderShard and DataLoaderDispatcher are created in the prepare_data_loader and skip_first_batches functions in the accelerate library. Both classes are subclasses of DataLoader, so they likely need to be modified or copied to extend StatefulDataLoader instead.
So IIUC, implementing this feature might involve the following steps?
1. In the accelerate library, either refactor DataLoaderShard and DataLoaderDispatcher to compose a StatefulDataLoader, or add new variants that inherit from StatefulDataLoader.
2. In the Trainer class, allow dropping in StatefulDataLoader instead of a regular DataLoader.
3. In the Trainer class, support saving the state_dict to, and loading it from, the checkpoint.
Thanks for pointing this out. I still might not be understanding correctly; maybe it's a lot more complicated than this.
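The composition option in step 1 could look roughly like the stdlib-only toy below. ToyStatefulLoader and ShardWrapper are hypothetical stand-ins for StatefulDataLoader and a DataLoaderShard-like wrapper, not accelerate code. The point of the sketch is that the wrapper keeps its own bookkeeping (here an epoch counter) and nests the inner loader's state dict inside its own, instead of inheriting from the loader.

```python
class ToyStatefulLoader:
    """Hypothetical stand-in for StatefulDataLoader: resumable batching over a list."""

    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size
        self.next_batch = 0

    def __iter__(self):
        while self.next_batch * self.batch_size < len(self.data):
            start = self.next_batch * self.batch_size
            self.next_batch += 1
            yield self.data[start:start + self.batch_size]
        self.next_batch = 0  # epoch finished: the next pass starts from the beginning

    def state_dict(self):
        return {"next_batch": self.next_batch}

    def load_state_dict(self, state):
        self.next_batch = state["next_batch"]


class ShardWrapper:
    """Hypothetical DataLoaderShard-like wrapper that composes a stateful loader."""

    def __init__(self, loader):
        self.loader = loader
        self.epoch = 0  # wrapper-level bookkeeping that must also survive a restart

    def __iter__(self):
        # Per-batch logic (e.g. device placement) would go here in a real wrapper.
        yield from self.loader
        self.epoch += 1

    def state_dict(self):
        # Nest the inner loader's state instead of flattening it into the wrapper's.
        return {"epoch": self.epoch, "loader": self.loader.state_dict()}

    def load_state_dict(self, state):
        self.epoch = state["epoch"]
        self.loader.load_state_dict(state["loader"])


data = list(range(10))
wrapper = ShardWrapper(ToyStatefulLoader(data, batch_size=4))
for _ in wrapper:  # one full epoch
    pass
it = iter(wrapper)
first = next(it)  # first batch of the second epoch
saved = wrapper.state_dict()

resumed = ShardWrapper(ToyStatefulLoader(data, batch_size=4))
resumed.load_state_dict(saved)
resumed_batch = next(iter(resumed))
print(saved, resumed_batch)  # {'epoch': 1, 'loader': {'next_batch': 1}} [4, 5, 6, 7]
```

Composition keeps the wrapper independent of the concrete loader class, at the cost of having to forward state_dict/load_state_dict explicitly.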
Correct, we need to:
1. Support StatefulDataLoader in accelerate and use it as an optional alternative in the DataLoaderConfiguration
2. Use it in the Trainer!
Makes sense. It also seems like there's a related issue raised in accelerate: https://github.com/huggingface/accelerate/issues/2859
Regarding using it in the Trainer, it feels a bit awkward. IIUC, the desired behavior is that if a StatefulDataLoader is being used and we are loading from a checkpoint, the Trainer should not call skip_first_batches at all, unless you also pass the state dict and checkpoint to that function. But IMO, skip_first_batches and "restore from checkpoint" feel like two separate concepts.
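One way to keep the two concepts separate is a resume step that prefers restoring a saved loader state and only falls back to skip_first_batches-style replay when no state is available. The helper resume_loader and the toy StatefulBatches loader below are hypothetical, stdlib-only sketches of that decision, not existing Trainer or accelerate code.

```python
class StatefulBatches:
    """Hypothetical toy loader over precomputed batches that can save/restore its position."""

    def __init__(self, batches):
        self.batches = batches
        self.pos = 0

    def __iter__(self):
        while self.pos < len(self.batches):
            self.pos += 1
            yield self.batches[self.pos - 1]

    def state_dict(self):
        return {"pos": self.pos}

    def load_state_dict(self, state):
        self.pos = state["pos"]


def resume_loader(loader, steps_done, saved_state=None):
    """Return an iterator positioned at the first batch not yet trained on.

    "Restore from checkpoint" and "skip first batches" are separate paths:
    the saved state wins whenever it exists and the loader can load it.
    """
    if saved_state is not None and hasattr(loader, "load_state_dict"):
        loader.load_state_dict(saved_state)
        return iter(loader)
    # Fallback: replay and discard, roughly what skipping does for iterable data.
    it = iter(loader)
    for _ in range(steps_done):
        next(it)
    return it


batches = [[0, 1], [2, 3], [4, 5]]
first_run = StatefulBatches(batches)
it = iter(first_run)
next(it)
next(it)
state = first_run.state_dict()

from_state = next(resume_loader(StatefulBatches(batches), steps_done=2, saved_state=state))
from_skip = next(resume_loader(batches, steps_done=2))  # plain list: no state, so replay
print(from_state, from_skip)  # [4, 5] [4, 5]
```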
Thank you for your responses @byi8220 @muellerzr.
Yes, I agree with you that if we properly manage the dataloader states in the Trainer, we no longer need accelerate's skip_first_batches for recovery.
As a workaround, I bypass accelerate to prepare my dataloaders by hacking the Trainer class to support stateful ones:
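```python
import transformers
from transformers.utils import logging

logger = logging.get_logger(__name__)

# HuggingFaceDataset, DPAwareDataLoader and DataCallback are user-defined classes
# (the first two adapted from torchtitan); they are not part of transformers.


class Trainer(transformers.Trainer):
    def get_train_dataloader(self) -> DPAwareDataLoader:
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        logger.info(f"Split the dataset for the node at rank {self.args.process_index} / {self.args.world_size}.")
        # Wrap the HF dataset so each rank reads its own split.
        # `context_length` is a custom field, not a standard TrainingArguments attribute.
        train_dataset = HuggingFaceDataset(self.train_dataset,
                                           self.tokenizer,
                                           self.args.context_length,
                                           self.args.process_index,
                                           self.args.world_size)
        # Build the stateful loader directly instead of going through accelerate.
        loader = DPAwareDataLoader(rank=self.args.process_index,
                                   dataset=train_dataset,
                                   batch_size=self.args.train_batch_size,
                                   collate_fn=self.data_collator,
                                   num_workers=self.args.dataloader_num_workers,
                                   pin_memory=self.args.dataloader_pin_memory,
                                   persistent_workers=self.args.dataloader_persistent_workers)
        # Register a callback that saves/restores the loader state with each checkpoint.
        data_callback = DataCallback(loader)
        self.add_callback(data_callback)
        return loader
```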
The DPAwareDataLoader is borrowed from torchtitan's implementation; that package is also developing similar ideas. I then use a custom callback to save/load the loader states:
```python
import json
import os

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
from transformers.trainer_callback import ExportableState
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, get_last_checkpoint


class DataCallback(TrainerCallback, ExportableState):
    """Saves and restores the stateful dataloader's position alongside each checkpoint."""

    def __init__(self, loader):
        self.loader = loader

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Resolve which checkpoint (if any) training is resuming from.
        output_dir = None
        if isinstance(args.resume_from_checkpoint, bool):
            if args.resume_from_checkpoint:
                output_dir = get_last_checkpoint(args.output_dir)
        elif args.resume_from_checkpoint is not None:
            output_dir = args.resume_from_checkpoint
        if output_dir is not None:
            # One state file per process, since each rank iterates over its own split.
            if args.world_size <= 1:
                data_state_pth = os.path.join(output_dir, "data_state.json")
            else:
                data_state_pth = os.path.join(output_dir, f"data_state_{args.process_index}.json")
            with open(data_state_pth, "r") as f:
                self.loader.load_state_dict(json.loads(f.read()))

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Same folder the Trainer uses for this step: <output_dir>/checkpoint-<global_step>.
        output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        if args.world_size <= 1:
            data_state_pth = os.path.join(output_dir, "data_state.json")
        else:
            data_state_pth = os.path.join(output_dir, f"data_state_{args.process_index}.json")
        with open(data_state_pth, "w") as f:
            f.write(json.dumps(self.state(), indent=2, sort_keys=True) + "\n")

    def state(self) -> dict:
        return self.loader.state_dict()
```
The default skip_first_batches step is disabled by passing --ignore_data_skip.
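The per-process file layout that DataCallback relies on can be checked in isolation. The helpers below (data_state_path, save_data_state, load_data_state) are hypothetical, not transformers APIs; they reuse the same file naming and JSON formatting as the callback. The round trip also shows that values come back as JSON types (a tuple comes back as a list), so this approach assumes the loader's state_dict is JSON-friendly.

```python
import json
import os
import tempfile


def data_state_path(checkpoint_dir, world_size, process_index):
    """One file for single-process runs, otherwise one file per process."""
    if world_size <= 1:
        return os.path.join(checkpoint_dir, "data_state.json")
    return os.path.join(checkpoint_dir, f"data_state_{process_index}.json")


def save_data_state(state, checkpoint_dir, world_size, process_index):
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = data_state_path(checkpoint_dir, world_size, process_index)
    with open(path, "w") as f:
        f.write(json.dumps(state, indent=2, sort_keys=True) + "\n")
    return path


def load_data_state(checkpoint_dir, world_size, process_index):
    with open(data_state_path(checkpoint_dir, world_size, process_index)) as f:
        return json.load(f)


# Simulate two ranks writing into the same checkpoint folder, then rank 1 reading back.
with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "checkpoint-500")
    for rank in range(2):
        save_data_state({"rank": rank, "offset": 100 * rank, "shard": (rank, 2)}, ckpt, 2, rank)
    files = sorted(os.listdir(ckpt))
    restored = load_data_state(ckpt, 2, 1)

print(files)     # ['data_state_0.json', 'data_state_1.json']
print(restored)  # {'offset': 100, 'rank': 1, 'shard': [1, 2]}
```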
I ran some minimal unit tests, and the states were recovered successfully without any noticeable delay.
This approach can be extremely useful when performing online tokenization with IterableDataset. Some people have benchmarked this and observed even faster speeds than with pre-tokenization: https://github.com/XinDongol/on-the-fly-tokenization-profiling. I've tried using stateful loaders with the hacky code above to train the Mamba model on subsets of the 627B SlimPajama data, which reduced the total training time from ~173h to ~170h. It could also save ~3TB of disk space compared to pre-tokenized map-style data.
So I'm really looking forward to an official implementation, and I'd be very happy to hear about any progress :D
Feature request
Add official support for StatefulDataLoader, as in torchdata and datasets.
Motivation
The StatefulDataLoader from the torchdata package provides a convenient way to recover a dataset iterator that was interrupted, without having to skip the first batches via a naive for loop, which can be time-consuming for extremely large datasets. The datasets package now officially supports stateful IterableDataset and its combination with StatefulDataLoader in v2.20.0.
Example usage:
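The usage pattern is: iterate, call state_dict() at some point, and later call load_state_dict() on a fresh object to continue where it stopped. As a stdlib-only illustration of why this avoids re-reading data, the hypothetical ToyShardedIterable below records a (shard index, example index) position. This is only conceptually similar to how a stateful iterable over sharded files can resume; it is not the datasets.IterableDataset or torchdata API.

```python
class ToyShardedIterable:
    """Hypothetical stateful iterable over shards; not the datasets or torchdata API."""

    def __init__(self, shards):
        self.shards = shards
        self.shard_idx = 0    # which shard we are in
        self.example_idx = 0  # next example to read inside that shard

    def __iter__(self):
        while self.shard_idx < len(self.shards):
            shard = self.shards[self.shard_idx]
            while self.example_idx < len(shard):
                example = shard[self.example_idx]
                self.example_idx += 1
                yield example
            self.shard_idx += 1
            self.example_idx = 0

    def state_dict(self):
        return {"shard_idx": self.shard_idx, "example_idx": self.example_idx}

    def load_state_dict(self, state):
        self.shard_idx = state["shard_idx"]
        self.example_idx = state["example_idx"]


shards = [["a", "b", "c"], ["d", "e"], ["f"]]
ds = ToyShardedIterable(shards)
seen = []
for i, example in enumerate(ds):
    seen.append(example)
    if i == 3:
        state = ds.state_dict()  # taken right after "d" was consumed
        break

ds2 = ToyShardedIterable(shards)
ds2.load_state_dict(state)
resumed = list(ds2)
print(seen, state, resumed)
# ['a', 'b', 'c', 'd'] {'shard_idx': 1, 'example_idx': 1} ['e', 'f']
```

Resuming seeks straight into the second shard instead of re-reading "a" through "d". Because the toy stores its position on the object itself, iterating it again continues rather than restarting, which a real implementation would handle more carefully.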
To enhance the usability and efficiency of the Trainer, it would be highly beneficial for the community if official support for StatefulDataLoader could be added. This would allow users to easily recover from interruptions and resume training from checkpoints without wasting time re-iterating over already processed batches. By integrating StatefulDataLoader into the Trainer, users could handle large datasets seamlessly and ensure a smooth training process. This feature would greatly improve the overall user experience and make the Trainer more robust and efficient. We kindly request that the development team consider adding official support for these features in the Trainer, as it would be a valuable addition to the library and benefit the wider community.