NVIDIA / Megatron-LM

Ongoing research training transformer models at scale
https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/index.html#quick-start

[BUG] `run_simple_mcore_train_loop.py` fails when `tensor_model_parallel_size` is changed from `2` to `1` #1038

Open 1195343015 opened 3 months ago

1195343015 commented 3 months ago

Describe the bug
https://github.com/NVIDIA/Megatron-LM/blob/01ca03f11e89f4f85682dcac647c2b913b25fcee/examples/run_simple_mcore_train_loop.py#L118
When I changed `tensor_model_parallel_size` in `run_simple_mcore_train_loop.py` from `2` to `1` (see the sketch below), the script crashed with a CUDA device-side assert.
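
For reference, this is roughly the edit at the linked line (a sketch; the call follows the example script at that commit, so treat the exact signature as approximate):

```python
# examples/run_simple_mcore_train_loop.py -- sketch of the edit at the linked line.
# The example initializes 2-way tensor parallelism by default:
#   initialize_distributed(tensor_model_parallel_size=2, pipeline_model_parallel_size=1)
# Changed to run without tensor parallelism:
initialize_distributed(tensor_model_parallel_size=1, pipeline_model_parallel_size=1)
```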

Stack trace/logs

/opt/pytorch/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [37,0,0], thread: [92,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/pytorch/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [37,0,0], thread: [93,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/pytorch/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [37,0,0], thread: [94,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/pytorch/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [37,0,0], thread: [95,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/Megatron-LM/examples/run_simple_mcore_train_loop.py", line 154, in <module>
[rank0]:     losses_reduced = forward_backward_func(
[rank0]:   File "/workspace/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 1379, in forward_backward_pipelining_without_interleaving
[rank0]:     output_tensor, num_tokens = forward_step(
[rank0]:   File "/workspace/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 259, in forward_step
[rank0]:     output_tensor, loss_func = forward_step_func(data_iterator, model)
[rank0]:   File "/workspace/Megatron-LM/examples/run_simple_mcore_train_loop.py", line 107, in forward_step_func
[rank0]:     output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1552, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1561, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/workspace/Megatron-LM/megatron/core/models/gpt/gpt_model.py", line 186, in forward
[rank0]:     decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1552, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1561, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/workspace/Megatron-LM/megatron/core/models/common/embeddings/language_model_embedding.py", line 103, in forward
[rank0]:     embeddings = word_embeddings + position_embeddings
[rank0]: RuntimeError: CUDA error: device-side assert triggered
[rank0]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank0]: For debugging consider passing CUDA_LAUNCH_BLOCKING=1
[rank0]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Environment (please complete the following information):

mksit commented 3 months ago

Do you have any solution? I got the same error.

1195343015 commented 2 months ago

> Do you have any solution? I got the same error.

I think this bug is caused by the inappropriate default `max_sequence_length` in `MockGPTLowLevelDataset`, which is used to generate the mock dataset:
https://github.com/NVIDIA/Megatron-LM/blob/732a689606810c02d0dc260a163c9ebac099c044/megatron/core/datasets/gpt_dataset.py#L693-L697
The default `max_sequence_length` is 4096. You can change it to 64 so that it matches the sequence length used in `run_simple_mcore_train_loop.py`:
https://github.com/NVIDIA/Megatron-LM/blob/732a689606810c02d0dc260a163c9ebac099c044/examples/run_simple_mcore_train_loop.py#L21
Hope this helps.
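
If you prefer not to edit `gpt_dataset.py` directly, a rough workaround sketch is to override the default before the mock dataset is built (this assumes `max_sequence_length` is a plain class-level attribute on `MockGPTLowLevelDataset`, as it appears at the linked lines; untested):

```python
# Workaround sketch: cap the mock dataset's sequence length to match the example
# script instead of editing megatron/core/datasets/gpt_dataset.py.
from megatron.core.datasets.gpt_dataset import MockGPTLowLevelDataset

SEQUENCE_LENGTH = 64  # same value the example script uses for the model

# Override the class-level default (4096 at the linked revision) before any
# mock dataset is constructed, e.g. near the top of run_simple_mcore_train_loop.py.
MockGPTLowLevelDataset.max_sequence_length = SEQUENCE_LENGTH
```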

mksit commented 2 months ago

@1195343015 Thanks a lot. It works for me. I hope this can be fixed more robustly.

github-actions[bot] commented 4 weeks ago

Marking as stale. No activity in 60 days.