pytorch / torchtitan

A native PyTorch Library for large model training

Make dataloader stateful? #291

Closed: XinDongol closed this issue 1 month ago

XinDongol commented 2 months ago

Resuming from a checkpoint currently restarts the dataloader from the beginning, which can cause issues for training since already-seen data would be sampled again. We may need to resume the dataloader from a saved state so that previously sampled data is skipped.
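For illustration, a minimal sketch of what a resumable dataloader interface could look like (class and method names below are hypothetical, not torchtitan's implementation): the dataset tracks how many samples it has yielded, exposes state_dict() / load_state_dict(), and skips already-consumed samples on resume.

```python
# Hypothetical sketch of a "stateful" iterable dataset: it tracks how many
# samples it has yielded so a resumed run can skip over them.
# Not torchtitan's actual code; names are made up for illustration.
from torch.utils.data import IterableDataset

class ResumableTextDataset(IterableDataset):
    def __init__(self, samples):
        self._samples = samples
        self._num_yielded = 0  # persisted across checkpoints

    def __iter__(self):
        for idx, sample in enumerate(self._samples):
            # Skip everything already consumed before the checkpoint.
            if idx < self._num_yielded:
                continue
            self._num_yielded += 1
            yield sample

    def state_dict(self):
        return {"num_yielded": self._num_yielded}

    def load_state_dict(self, state_dict):
        self._num_yielded = state_dict["num_yielded"]
```

The checkpointer would then save and restore this state alongside the model and optimizer states.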

tianyu-l commented 2 months ago

Thanks for creating this issue! In fact we recently started working on it, in #279.

XinDongol commented 2 months ago

I tested the branch and got this error:

    File "torchtitan/train.py", line 255, in main
      checkpoint.load()
    File "torchtitan/torchtitan/checkpoint.py", line 217, in load
      dcp.load(
    File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/utils.py", line 427, in inner_func
      return func(*args, **kwargs)
    File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/state_dict_loader.py", line 174, in load
      elem.load_state_dict(statetful_sd[key])
    File "torchtitan/torchtitan/checkpoint.py", line 79, in load_state_dict
      self.dataloader.load_state_dict(state_dict[self.rank_id])
  KeyError: 1

It is weird, but all ranks are receiving the same state_dict (rank 0's state_dict).

I'm not quite sure what is happening here, but this might be the reason:

https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict_loader.load

> WARNING: All tensors in state_dict must be allocated on their destination device prior to calling this function. All non-tensor data is loaded using torch.load() and modified in place on state_dict.
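To spell out my reading of the traceback, here is a hypothetical reconstruction of the wrapper (not the actual checkpoint.py) that would produce KeyError: 1 if every rank is handed rank 0's dict:

```python
# Hypothetical reconstruction based on the traceback above, for illustration only
# (not the real torchtitan/torchtitan/checkpoint.py).
import torch.distributed as dist

class DataLoaderWrapper:
    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.rank_id = dist.get_rank() if dist.is_initialized() else 0

    def state_dict(self):
        # Keys differ per rank: rank 0 saves {0: ...}, rank 1 saves {1: ...}.
        return {self.rank_id: self.dataloader.state_dict()}

    def load_state_dict(self, state_dict):
        # If dcp.load() hands every rank the dict saved by rank 0 (because the
        # values are not DTensors), ranks > 0 only find key 0 here -> KeyError: 1.
        self.dataloader.load_state_dict(state_dict[self.rank_id])
```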

tianyu-l commented 2 months ago

@XinDongol Thanks for trying it out! The work is still in progress and not ready yet. Yes, we are aware of the issue you mentioned and are working on distributed checkpointing to fix it. We hope to resolve this problem soon.

XinDongol commented 2 months ago

I was wondering whether you have found the root cause of all ranks receiving the same dataloader state_dict? My guess is that it is because the state_dict is not a DTensor? (Not sure at all.)

tianyu-l commented 2 months ago

> I was wondering whether you have found the root cause of all ranks receiving the same dataloader state_dict? My guess is that it is because the state_dict is not a DTensor? (Not sure at all.)

Yes. Currently, if a value is not a DTensor, only rank 0's value is saved. After the fix, we'd like to be able to save values across all ranks, as long as the state_dict keys are different on each rank.

XinDongol commented 2 months ago

@tianyu-l @gokulavasan Thanks for the reply. One more note: the current implementation does not support num_workers > 1. If we set num_workers > 1, different workers will load duplicated samples. This is a common issue with IterableDataset, as discussed here, but it can be solved with a few tricks (see the sketch below). When fixing the state_dict issue of the dataloader, it would be great if you could take this into consideration.
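For reference, the usual trick is to shard the iterable dataset across dataloader workers with get_worker_info() (a hedged sketch, not torchtitan code; in practice this would be combined with the per-rank sharding the dataset already does):

```python
# Sketch: shard an IterableDataset across dataloader workers so that
# num_workers > 1 does not yield duplicated samples.
# Each worker keeps every num_workers-th sample.
from torch.utils.data import IterableDataset, get_worker_info

class ShardedTextDataset(IterableDataset):
    def __init__(self, samples):
        self._samples = samples

    def __iter__(self):
        info = get_worker_info()
        if info is None:                  # num_workers == 0: single process
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = info.id, info.num_workers
        for idx, sample in enumerate(self._samples):
            if idx % num_workers == worker_id:
                yield sample
```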

tianyu-l commented 2 months ago

@XinDongol Thanks for the note! I believe we are aware of the issue (@gokulavasan to double check).

The reason we didn't prioritize supporting num_workers > 1 is that Llama training is GPU-bound, so even if we load data in the main process, that data-loading work can be overlapped with the remaining GPU work from the previous iteration. Besides, the time spent on data loading is almost negligible compared with the time spent on training.

For these reasons, we think it's better not to introduce the additional complexity. However, things may change if we support multi-modal training, where loading images / video could take much longer. Happy to hear your thoughts on it.

XinDongol commented 2 months ago

I agree. For very large models, that may be the case.

Torchtitan currently does on-the-fly tokenization. I really like the idea of on-the-fly tokenization: it is great for SFT and makes changing the tokenizer very easy. I did some profiling and found that on-the-fly tokenization is 3x slower than pre-tokenization when num_workers=1. https://github.com/XinDongol/on-the-fly-tokenization-profiling

Increasing num_workers speeds up on-the-fly tokenization even more, because reading text is more IO-efficient than reading tokens. I tried a 1B model and found that data loading takes about 10% of end-to-end time with num_workers=1 for torchtitan with on-the-fly tokenization; using num_workers=8 reduces it to about 1%. So I think supporting num_workers > 1 could still be helpful.
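Roughly how I measured the data-loading share (a simplified sketch; the helper names and the explicit synchronize are my own, not the torchtitan train loop):

```python
# Simplified measurement sketch: fraction of step time spent waiting on the
# dataloader. run_step is assumed to do forward/backward/optimizer work.
import time
import torch

def profile_data_fraction(data_iter, run_step, device, num_steps=100):
    data_time, total_time = 0.0, 0.0
    for _ in range(num_steps):
        t0 = time.perf_counter()
        batch = next(data_iter)           # time spent waiting on the dataloader
        t1 = time.perf_counter()
        run_step(batch)                   # queue the GPU work for this step
        torch.cuda.synchronize(device)    # wait so GPU time is visible on the CPU
        t2 = time.perf_counter()
        data_time += t1 - t0
        total_time += t2 - t0
    return data_time / total_time         # e.g. ~0.10 with num_workers=1
```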

tianyu-l commented 1 month ago

@XinDongol Thanks a lot for your feedback!

> I tried a 1B model and found that data loading takes about 10% of end-to-end time with num_workers=1 for torchtitan with on-the-fly tokenization.

I wonder what log_freq you used for this experiment? Every log step inserts a CPU/GPU synchronization (to read the loss), which reduces the opportunity to overlap data loading with GPU work. If log_freq = 1 was used for this experiment, I wonder how much time data loading would take with log_freq = 10 or some other value.
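To make the concern concrete, here is a hedged sketch (not the actual train loop) of how log_freq gates the synchronization:

```python
# Sketch of the interaction between log_freq and CPU/GPU overlap.
# loss.item() forces the CPU to wait for the GPU, so calling it every step
# removes the window in which the next batch could be fetched while the GPU
# is still busy; calling it every log_freq steps keeps most of that overlap.
def train(model_step, data_iter, num_steps, log_freq=10):
    for step in range(1, num_steps + 1):
        batch = next(data_iter)      # CPU work, overlaps with in-flight GPU kernels
        loss = model_step(batch)     # queues GPU kernels, returns a CUDA tensor
        if step % log_freq == 0:
            # .item() synchronizes: the CPU blocks until the GPU finishes this step.
            print(f"step {step}: loss = {loss.item():.4f}")
```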