NVIDIA / Megatron-LM

Ongoing research training transformer models at scale
https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/index.html#quick-start

[BUG] Mismatch Between Docstring and Behavior in core.tensor_parallel.random.model_parallel_cuda_manual_seed #858

Open cong-bai opened 3 months ago

cong-bai commented 3 months ago

Describe the bug The following function (in megatron.core.tensor_parallel.random) is called when we initialize the random seeds. I suspect that this function's behavior does not match its docstring, even allowing for the different ways a seed might be passed in (a small stand-alone sketch after the snippet makes the mismatch concrete):

def model_parallel_cuda_manual_seed(seed):
    """Initialize model parallel cuda seed.

    This function should be called after the model parallel is
    initialized. Also, no torch.cuda.manual_seed should be called
    after this function. Basically, this is replacement for that
    function.
    Two sets of RNG states are tracked:
    default state: This is for data parallelism and is the same among a set of model parallel GPUs but different across different model parallel groups. This is used, for example, for dropout in the non-tensor-model-parallel regions.
    tensor-model-parallel state: This state is different among a set of model parallel GPUs, but the same across data parallel groups. This is used, for example, for dropout in model parallel regions.
    """
    # 2718 is just for fun and any POSITIVE value will work.
    offset = seed + 2718
    tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
    # Data parallel gets the original seed.
    data_parallel_seed = seed

    initialize_rng_tracker()
    _CUDA_RNG_STATE_TRACKER.reset()
    # Set the default state.
    torch.cuda.manual_seed(data_parallel_seed)
    _CUDA_RNG_STATE_TRACKER.add(_DATA_PARALLEL_RNG_TRACKER_NAME, data_parallel_seed)

    # and model parallel state.
    _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed)

    expert_parallel_seed = (
        seed + 1024 + 100 * get_expert_model_parallel_rank() + get_tensor_model_parallel_rank()
    )
    _CUDA_RNG_STATE_TRACKER.add(_EXPERT_PARALLEL_RNG_TRACKER_NAME, expert_parallel_seed)
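To make the mismatch concrete, here is a minimal stand-alone sketch (my own, not Megatron code) that replays the seed arithmetic above for a hypothetical layout of 2 tensor-parallel ranks by 2 data-parallel ranks; tp_rank and dp_rank are stand-ins for the mpu rank getters:

def tracker_seeds(seed, tp_rank):
    # Replays the arithmetic from model_parallel_cuda_manual_seed above.
    offset = seed + 2718
    return {
        "default": seed,                      # the "data parallel" state
        "tensor-parallel": offset + tp_rank,  # differs only by tensor rank
    }

SEED = 1234  # arbitrary base seed, for illustration only
for dp_rank in range(2):
    for tp_rank in range(2):
        print(f"dp={dp_rank} tp={tp_rank} -> {tracker_seeds(SEED, tp_rank)}")

On this toy grid the "default" seed is 1234 on every rank, including both data-parallel ranks: the function by itself never differentiates data-parallel groups, which contradicts "different across different model parallel groups" unless the caller has already varied the seed per data-parallel rank.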

To Reproduce N/A

Expected behavior I assume the docstring describes the intended (and therefore expected) behavior of this function.

Stack trace/logs N/A

Environment (please complete the following information):

Proposed fix N/A

Additional context I suspect model_parallel_cuda_manual_seed is meant to take care of all the details of seeding under the different parallelisms. If that is true, then the following usage (in megatron/training/initialize.py) seems problematic (see the sketch after the snippet):

def _set_random_seed(seed_, data_parallel_random_init=False):
    """Set random seed for reproducability."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
        # Ensure different data parallel ranks get different seeds
        if data_parallel_random_init:
            seed = seed + (10 * mpu.get_data_parallel_rank())
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            tensor_parallel.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError("Seed ({}) should be a positive integer.".format(seed))
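Here is a similar sketch (again my own, not Megatron code) composing the two functions for 2 pipeline stages, 2 tensor ranks, and 2 data-parallel ranks, with data_parallel_random_init=False as is the default; pp_rank, tp_rank, and dp_rank stand in for the mpu getters:

def effective_seeds(base, pp_rank, tp_rank, dp_rank, data_parallel_random_init=False):
    # Replays _set_random_seed followed by model_parallel_cuda_manual_seed.
    seed = base + 100 * pp_rank   # per-pipeline-stage offset
    if data_parallel_random_init:
        seed += 10 * dp_rank      # optional per-data-parallel-rank offset
    return {
        "default": seed,          # becomes the "data parallel" state
        "tensor-parallel": seed + 2718 + tp_rank,
    }

BASE = 1234  # arbitrary base seed, for illustration only
for pp_rank in range(2):
    for tp_rank in range(2):
        for dp_rank in range(2):
            seeds = effective_seeds(BASE, pp_rank, tp_rank, dp_rank)
            print(f"pp={pp_rank} tp={tp_rank} dp={dp_rank} -> {seeds}")

With the defaults, the "default" state comes out as 1234 on pipeline stage 0 but 1334 on stage 1, so it is not "the same among a set of model parallel GPUs" (a model-parallel group spans pipeline stages), while it is identical across data-parallel ranks, the opposite of what the docstring promises on both counts.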
github-actions[bot] commented 1 month ago

Marking as stale. No activity in 60 days.