NVIDIA / Megatron-LM

Ongoing research training transformer models at scale
https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/index.html#quick-start

[BUG] GPTDataset._build_document_sample_shuffle_indices does not build the indices on non-root nodes when not using NFS #907

Open dementrock opened 3 months ago

dementrock commented 3 months ago

Describe the bug If the training data lives on node-specific storage rather than on NFS, the current logic in https://github.com/NVIDIA/Megatron-LM/blob/0bc3547702464501feefeb5523b7a17e591b21fa/megatron/core/datasets/gpt_dataset.py#L346 builds the indices only on global rank 0, so every other node skips building them. Those nodes then fail when loading the document index at https://github.com/NVIDIA/Megatron-LM/blob/0bc3547702464501feefeb5523b7a17e591b21fa/megatron/core/datasets/gpt_dataset.py#L484, complaining that the file does not exist.

To Reproduce Run multi-node training with the training data on node-local storage rather than on NFS.

Expected behavior Ideally there should be a flag indicating whether the data storage is a shared file system. If it is not, the index needs to be built on each node separately.
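A minimal sketch of the decision the expected behavior describes, assuming a hypothetical `shared_filesystem` flag (not an existing Megatron-LM option): on a shared file system only global rank 0 builds the cache, as the current code does; on node-local storage the first rank of each node (local rank 0) builds its own copy. It also assumes the common launcher layout where global ranks are assigned node by node.

```python
from typing import List


def builds_indices(rank: int, local_rank: int, shared_filesystem: bool, cache_hit: bool) -> bool:
    """Return True if this rank should build and save the index cache.

    `shared_filesystem` is a hypothetical flag, not an existing Megatron-LM setting.
    """
    if cache_hit:
        # The cache files are already visible to this rank: just load them.
        return False
    if shared_filesystem:
        # One writer for the whole job; other nodes see the files through the shared mount.
        return rank == 0
    # Node-local storage: every node needs its own writer.
    return local_rank == 0


def builder_ranks(num_nodes: int, gpus_per_node: int, shared_filesystem: bool) -> List[int]:
    """List the global ranks that would build, assuming ranks are assigned node by node."""
    builders = []
    for rank in range(num_nodes * gpus_per_node):
        local_rank = rank % gpus_per_node
        if builds_indices(rank, local_rank, shared_filesystem, cache_hit=False):
            builders.append(rank)
    return builders


if __name__ == "__main__":
    # 2 nodes x 8 GPUs, the setup reported later in this thread.
    print(builder_ranks(2, 8, shared_filesystem=True))   # [0]
    print(builder_ranks(2, 8, shared_filesystem=False))  # [0, 8]
```

Whichever ranks build, all ranks would still need to synchronize after building and before loading, so that no rank reads a file that has not been written yet.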

Stack trace/logs

(worker6, rank=6, pid=8930, ip=10.42.3.242)   File "/opt/megatron-lm/megatron/core/datasets/blended_megatron_dataset_builder.py", line 470, in build_generic_dataset
(worker6, rank=6, pid=8930, ip=10.42.3.242)     dataset = cls(*args)
(worker6, rank=6, pid=8930, ip=10.42.3.242)   File "/opt/megatron-lm/megatron/core/datasets/gpt_dataset.py", line 111, in __init__
(worker6, rank=6, pid=8930, ip=10.42.3.242)     ) = self._build_document_sample_shuffle_indices()
(worker6, rank=6, pid=8930, ip=10.42.3.242)   File "/opt/megatron-lm/megatron/core/datasets/gpt_dataset.py", line 474, in _build_document_sample_shuffle_indices
(worker6, rank=6, pid=8930, ip=10.42.3.242)     document_index = numpy.load(path_to_document_index, allow_pickle=True, mmap_mode='r')
(worker6, rank=6, pid=8930, ip=10.42.3.242)   File "/usr/local/lib/python3.10/dist-packages/numpy/lib/npyio.py", line 405, in load
(worker6, rank=6, pid=8930, ip=10.42.3.242)     fid = stack.enter_context(open(os_fspath(file), "rb"))
(worker6, rank=6, pid=8930, ip=10.42.3.242) FileNotFoundError: [Errno 2] No such file or directory: '/wiki/mistral_7b_v0.3_training_data_text_document/cache/GPTDataset_indices/81e3d4d910e734899c56ceb4ba98b98c-GPTDataset-train-document_index.npy'
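To see which nodes can actually read the cache, a small diagnostic can be run on every node. This is a sketch that assumes the naming visible in the trace: a `<hash>-GPTDataset-<split>-<suffix>` file inside `<data prefix>/cache/GPTDataset_indices/`. Only the `document_index.npy` suffix is confirmed by the trace; the other suffixes listed are assumptions.

```python
import os
from typing import Callable, Iterable, List

# Only "document_index.npy" appears in the trace; the other suffixes are assumed.
DEFAULT_SUFFIXES = ("description.txt", "document_index.npy", "sample_index.npy", "shuffle_index.npy")


def missing_index_files(
    cache_dir: str,
    unique_hash: str,
    split: str = "train",
    class_name: str = "GPTDataset",
    suffixes: Iterable[str] = DEFAULT_SUFFIXES,
    exists: Callable[[str], bool] = os.path.isfile,
) -> List[str]:
    """Return the expected cache files that are not visible from this node."""
    missing = []
    for suffix in suffixes:
        path = os.path.join(cache_dir, f"{unique_hash}-{class_name}-{split}-{suffix}")
        if not exists(path):
            missing.append(path)
    return missing


if __name__ == "__main__":
    import socket

    # Cache directory and hash taken from the FileNotFoundError in the trace.
    cache_dir = "/wiki/mistral_7b_v0.3_training_data_text_document/cache/GPTDataset_indices"
    missing = missing_index_files(cache_dir, "81e3d4d910e734899c56ceb4ba98b98c")
    print(socket.gethostname(), "missing:", missing or "none")
```

If rank 0's node reports "none" while the other nodes list missing files, the cache was written to storage that only rank 0's node can see.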

Environment (please complete the following information):

Proposed fix My workaround is the following patch:

--- /opt/megatron-lm/megatron/core/datasets/gpt_dataset.py      2024-07-07 03:48:09.635073980 +0000
+++ /opt/megatron-lm/megatron/core/datasets/gpt_dataset.py.new  2024-07-07 03:48:07.383130640 +0000
@@ -8,6 +8,7 @@

 import numpy
 import torch
+import torch.distributed

 from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
 from megatron.core.datasets.indexed_dataset import IndexedDataset
@@ -342,7 +343,7 @@

         if not path_to_cache or (
             not cache_hit
-            and (not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0)
+            and (not torch.distributed.is_initialized() or os.environ.get('LOCAL_RANK', '0') == '0')
         ):

             log_single_rank(
@@ -459,7 +460,9 @@
             )
             log_single_rank(logger, logging.INFO, f"> total number of epochs: {num_epochs}")

-            return document_index, sample_index, shuffle_index
+            # return document_index, sample_index, shuffle_index
+
+        torch.distributed.barrier()

         log_single_rank(
             logger, logging.INFO, f"Load the {type(self).__name__} {self.index_split.name} indices"

But it does not offer the flexibility of a flag.
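The patch amounts to a "one writer per node, barrier, then everyone reads" pattern. The sketch below illustrates that ordering with the standard library only: threads stand in for the ranks of one node, `threading.Barrier` stands in for `torch.distributed.barrier()`, and `build_indices` is a hypothetical placeholder for the real index construction.

```python
import os
import tempfile
import threading
from typing import Dict


def build_indices(path: str) -> None:
    """Hypothetical stand-in for the real index build: write the cache file atomically."""
    tmp_path = path + ".tmp"
    with open(tmp_path, "w") as f:
        f.write("document_index")
    os.replace(tmp_path, path)  # readers never observe a half-written file


def run_node(gpus_per_node: int, cache_dir: str) -> Dict[int, str]:
    """Simulate one node: local rank 0 builds the cache, all ranks wait, then all load it."""
    path = os.path.join(cache_dir, "document_index.txt")
    barrier = threading.Barrier(gpus_per_node)  # stands in for torch.distributed.barrier()
    loaded: Dict[int, str] = {}
    lock = threading.Lock()

    def rank_main(local_rank: int) -> None:
        # Mirrors the patched condition: only LOCAL_RANK 0 builds on a cache miss.
        if local_rank == 0 and not os.path.isfile(path):
            build_indices(path)
        barrier.wait()  # no rank reads before the writer has finished
        with open(path) as f:
            content = f.read()
        with lock:
            loaded[local_rank] = content

    threads = [threading.Thread(target=rank_main, args=(r,)) for r in range(gpus_per_node)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return loaded


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as node_local_dir:
        print(run_node(8, node_local_dir))
```

In the patch itself the barrier is `torch.distributed.barrier()` over the default group, so every rank in the job waits for the slowest node to finish building. The atomic `os.replace` write is an extra precaution in this sketch and is not part of the patch.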

nakroy commented 3 months ago

I ran into the same problem when using 2 nodes (16 GPUs) to fine-tune a Llama 2 model. I use NFS to synchronize the dataset files, but the second node still raises a FileNotFoundError, even though the cached training-index files had already been synchronized after the first node generated them.

Dune-Z commented 3 months ago

I hit the same problem when using multiple nodes; moving the dataset to a shared disk solved it.

nakroy commented 3 months ago

Setting the seed and using NFS to synchronize the dataset solved the problem.
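One plausible reason the seed matters: the cached file names start with a 32-character hexadecimal prefix (see the path in the stack trace) that identifies the dataset configuration. The sketch below assumes that prefix is an MD5 digest of a description of the configuration that includes the random seed; the field names and values are hypothetical. Under that assumption, nodes launched with different seeds (or any other differing setting) look for different file names, so even a correctly shared NFS directory yields a FileNotFoundError.

```python
import hashlib
import json
from typing import Any, Dict


def cache_prefix(description: Dict[str, Any]) -> str:
    """Hash a dataset description into a cache-file prefix (assumed: MD5 of sorted JSON)."""
    text = json.dumps(description, sort_keys=True)
    return hashlib.md5(text.encode("utf-8")).hexdigest()


# Hypothetical description fields; the real fields are defined by Megatron-LM.
example_description = {
    "dataset_path": "/wiki/mistral_7b_v0.3_training_data_text_document",
    "split": "train",
    "sequence_length": 4096,
    "random_seed": 1234,
}

if __name__ == "__main__":
    same = cache_prefix(dict(example_description))
    other_seed = cache_prefix({**example_description, "random_seed": 4321})
    print(same == cache_prefix(example_description))  # True: same settings, same file names
    print(same == other_seed)                         # False: different seed, different file names
```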

github-actions[bot] commented 1 month ago

Marking as stale. No activity in 60 days.

nakroy commented 3 weeks ago

Hi, you can check this simple tutorial to learn how to use NFS to share files between different nodes. https://bluexp.netapp.com/blog/azure-anf-blg-linux-nfs-server-how-to-set-up-server-and-client

I recommend installing nfs-kernel-server on the master node and nfs-client on the other nodes. Create a dataset directory (for your Megatron-LM training) on the master node and grant the other nodes access to it in the exports file (/etc/exports), so that the training dataset is shared across all your nodes.
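Before launching, it can help to confirm on each node that the dataset directory really sits on the NFS mount rather than on local disk. A minimal, Linux-only sketch, assuming `/proc/mounts` is available: it picks the longest mount point containing the path and reports that mount's file-system type (`nfs` or `nfs4` on an NFS client). The `/wiki` path is a placeholder taken from the trace.

```python
import os
from typing import Optional


def filesystem_type(path: str, mounts_text: Optional[str] = None) -> Optional[str]:
    """Return the file-system type of the mount that contains `path` (Linux only).

    Lines in /proc/mounts look like: "<device> <mount point> <fstype> <options> 0 0".
    Mount points containing spaces are escaped as \\040 there; that case is not handled.
    """
    if mounts_text is None:
        with open("/proc/mounts") as f:
            mounts_text = f.read()
    path = os.path.realpath(path)
    best_mount, best_type = "", None
    for line in mounts_text.splitlines():
        fields = line.split()
        if len(fields) < 3:
            continue
        mount_point, fstype = fields[1], fields[2]
        # The mount contains `path` if it is the same directory or one of its parents.
        prefix = mount_point.rstrip("/") + "/"
        if (path == mount_point or path.startswith(prefix)) and len(mount_point) > len(best_mount):
            best_mount, best_type = mount_point, fstype
    return best_type


if __name__ == "__main__" and os.path.exists("/proc/mounts"):
    # Expect "nfs" or "nfs4" if /wiki is an NFS client mount on this node.
    print(filesystem_type("/wiki"))
```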

By the way, it may not work the first time you launch the training scripts (it looks like a bug where dist.barrier() does not work well in newer PyTorch versions), but the master node will still generate the training files successfully. So you can simply launch training again: the second time, the training files are not regenerated and training starts.
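An alternative to relaunching, not suggested in this thread, is to have the reading ranks poll for the cache files after the barrier. On NFS, a file created by one client can take a while to become visible on another, because clients cache directory contents and attributes. A minimal sketch; `timeout_s` and `poll_s` are hypothetical parameters, and this complements rather than replaces the barrier.

```python
import os
import time
from typing import Iterable


def wait_for_files(paths: Iterable[str], timeout_s: float = 600.0, poll_s: float = 1.0) -> bool:
    """Poll until every path is a regular file; return False if `timeout_s` elapses first."""
    paths = list(paths)
    deadline = time.monotonic() + timeout_s
    while True:
        if all(os.path.isfile(p) for p in paths):
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_s)
```

A rank would call this on the index paths before `numpy.load`, and raise a clear error naming the missing files if it returns False.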

------------------ Original message ------------------ From: "NVIDIA/Megatron-LM" @.>; Sent: Monday, September 23, 2024, 4:23 PM @.>; @.**@.>; Subject: Re: [NVIDIA/Megatron-LM] [BUG] GPTDataset._build_document_sample_shuffle_indices does not build the indices on non-root nodes when not using NFS (Issue #907)

Hello, I ran into the same problem, and I don't understand how to use NFS. Could you please share this method with me?
