microsoft / DeepSpeed

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
https://www.deepspeed.ai/
Apache License 2.0

Performance Degradation with ZeRO Stage 3 #1069

Closed: stephenrawls closed this issue 1 year ago

stephenrawls commented 3 years ago

Hi,

I am trying to benchmark a 10B parameter Hugging Face RobertaForMaskedLM model with both ZeRO Stage 2 and ZeRO Stage 3 to compare the latency impact of parameter partitioning.

However, I am seeing much worse performance with Stage 3 than expected, so I want to check whether something looks wrong.

|--------------------+-------+--------+-------+---------+--------|
| Description        | Model | # p3dn | batch | Samples | TFLOPS |
|                    | Size  |  hosts |  size | Per Sec |  / GPU |
|--------------------+-------+--------+-------+---------+--------|
| Baseline (stage2)  | 10B   |     16 |     8 |     178 |   44.7 |
| Stage3, no offload | 10B   |     16 |     8 |      77 |   19.5 |
| Stage3, no offload | 10B   |      8 |     8 |      41 |   21.9 |
| Stage3, no offload | 10B   |      4 |     8 |      23 |   23.5 |
| Stage3, no offload | 10B   |      2 |     8 |    11.6 |   23.5 |
| Stage3, no offload | 10B   |      1 |     8 |     OOM |    OOM |
|--------------------+-------+--------+-------+---------+--------|

The problem does not seem to be related to network bandwidth, because when I move to p4d machines, which have 4x the bandwidth of p3dn machines (400 Gbps vs 100 Gbps), I see similar degradation:

|--------------------+-------+--------+-------+---------+--------|
| Description        | Model | # p4d  | batch | Samples | TFLOPS |
|                    | Size  |  hosts |  size | Per Sec |  / GPU |
|--------------------+-------+--------+-------+---------+--------|
| Baseline (stage2)  | 10B   |     16 |     8 |     432 |    109 |
| Stage3, no offload | 10B   |      4 |     8 |      44 |   44.5 |
|--------------------+-------+--------+-------+---------+--------|

I tried increasing stage3_max_live_parameters from 1e9 → 2e9, and stage3_prefetch_bucket_size from 5e8 → 10e8, but neither change impacted performance.
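For reference, these knobs sit in the zero_optimization section of the config; a sketch with the increased values (the variable name is just for illustration, other fields omitted):

    ds_config_zero3 = {
        "zero_optimization": {
            "stage": 3,
            "overlap_comm": True,
            # defaults are 1e9 and 5e8; these are the increased values tried above
            "stage3_max_live_parameters": 2e9,
            "stage3_prefetch_bucket_size": 10e8,
        },
    }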

In addition, I ended up adding some time.time() statements before and after:

a. the blocking fetch() call: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage3.py#L1520
b. the non-blocking pre-fetch() call: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage3.py#L1525-L1528
c. the release() call: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage3.py#L1540
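Roughly this kind of instrumentation (a generic, runnable sketch; not the actual stage3.py code):

    import time
    from contextlib import contextmanager

    @contextmanager
    def timed(totals, key):
        """Accumulate wall-clock time (ms) spent inside the block under `key`."""
        start = time.time()
        try:
            yield
        finally:
            totals[key] = totals.get(key, 0.0) + (time.time() - start) * 1000.0

    totals = {}
    with timed(totals, "fetch"):   # wrapped around e.g. the blocking fetch() call site
        time.sleep(0.01)           # stand-in for the real work
    print(totals)                  # e.g. {'fetch': ~10.0}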

And I noticed counter-intuitively that the majority of time was spent in what is supposed to be the non-blocking pre-fetch call:

Total fetch time = 599.5581150054932 ms;
Total pre-fetch time = 4473.618030548096 ms;
Total release time = 1130.7482719421387 ms

Total time = 6203.9244174957275 ms

In fact, after a bit of digging and adding some additional timing statements to the code, I isolated the cause of the slow pre-fetch to this line: https://github.com/microsoft/DeepSpeed/blob/18a26e8604c4cb8562ed8d57241ca64dbeb4318a/deepspeed/runtime/zero/partition_parameters.py#L798

Any ideas why I am seeing a 2x or bigger drop in performance when moving to stage 3 (compared to stage 2)? And why pre-fetching seems to be taking so much time when it is supposed to be an asynchronous background operation?

Thanks, Stephen

P.S. Details below.

Model config:

RobertaForMaskedLM:
    max_position_embeddings: 512
    type_vocab_size: 1
    num_attention_heads: 40
    num_hidden_layers: 30
    hidden_size: 5120
    intermediate_size: 20480
    gradient_checkpointing: true

Zero stage 2 config:

  zero_optimization:
    stage: 2
    overlap_comm: true

Zero stage 3 config:

  zero_optimization:
    stage: 3
    overlap_comm: true

tjruwase commented 3 years ago

@stephenrawls, can you please share the command line for these results?

jfc4050 commented 3 years ago

Hello, I am seeing a similar gap. In my case I am training a ~2B parameter RoBERTa model on a single node with 8 GPUs, getting ~52 TFLOPS using Stage 2 and ~19 with Stage 3. This gap seems quite a bit larger than expected.

Here's a minimal script to reproduce: https://gist.github.com/jfc4050/ab1ecddc9c13290ecc1dde4a4eee0358

I've hardcoded the relevant configuration into the script, but it is as follows:

    ds_config = {
        "memory_breakdown": True,
        "wall_clock_breakdown": True,
        "train_micro_batch_size_per_gpu": 8,
        "gradient_accumulation_steps": 1,
        "gradient_clipping": 1.0,
        "fp16": {"enabled": True},
        "zero_optimization": {
            "stage": 3,
            "overlap_comm": True,
            "contiguous_gradients": False,
            "reduce_bucket_size": 1.5e8,
            "allgather_bucket_size": 1.5e8,
        },
        "activation_checkpointing": {
            "partition_activations": False,
            "contiguous_memory_optimization": False,
            "number_checkpoints": 50,
            "cpu_checkpointing": False,
            "profile": True,
        },
    }

jfc4050 commented 3 years ago

[profiler screenshot]

Looking at this random subsection of a training run, it seems like there's a lot of overhead from the pre and post sub-module hooks, similar to what @stephenrawls saw. I'm not seeing the same overhead from the non-blocking prefetch call, though. The fetch operation taking some time seems reasonable, but we also spend quite a lot of time doing what appears to be releasing parameters.

The GPU isn't all that busy during this time, so hopefully there's room to push utilization up a bit.

Wrapping up for the day, but tomorrow I will spend some time familiarizing myself with the code and looking a little more closely at the profile.

jfc4050 commented 3 years ago

I tried putting the post-sub-module-forward parameter deallocation into its own stream and making it non-blocking.
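Roughly this pattern, as a generic sketch of moving work onto a side stream (the copy is just a stand-in for the release work; not the actual DeepSpeed change):

    import torch

    side_stream = torch.cuda.Stream()

    x = torch.randn(1024, 1024, device="cuda")
    scratch = torch.empty_like(x)

    # make the side stream wait for whatever produced `x` on the current stream
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        scratch.copy_(x, non_blocking=True)   # stand-in for the parameter release work
    # tell the caching allocator that `scratch` is still in use on the side stream
    scratch.record_stream(side_stream)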

before:

-------------------------- DeepSpeed Flops Profiler --------------------------
Summary of forward pass:
Profile step:                   10
Number of parameters:           394
Number of multiply-accumulate operations (MACs):   8267.5 G
Number of floating point operations ( = 2 * MACs):   16534.99 G
Latency:                        838.62 ms
Floating point operations per second(FLOPS):   19.72 TFLOPS

after:

-------------------------- DeepSpeed Flops Profiler --------------------------
Summary of forward pass:
Profile step:                   10
Number of parameters:           394
Number of multiply-accumulate operations (MACs):   8267.5 G
Number of floating point operations ( = 2 * MACs):   16534.99 G
Latency:                        621.79 ms
Floating point operations per second(FLOPS):   26.59 TFLOPS

This seems to have gotten rid of that big green cudaStreamSynchronize from earlier; no idea why it was taking so long. There's also a chance I've introduced a race condition somewhere, so there may be more work needed to make this "stream-safe".

The flops profiler seems to fluctuate by a couple of TFLOPS from run to run, since it only measures a single step; I wonder if there is a way to have it aggregate over the entire run.

GPU utilization is still pretty low.

[profiler screenshot]

jfc4050 commented 3 years ago

This was the line causing the 2+ ms CUDA synchronize; I'm pretty surprised it was blocking for that long:

https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/partition_parameters.py#L669

jfc4050 commented 3 years ago

Taking a look at another expensive synchronization here: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/partition_parameters.py#L798

I don't feel confident yet that this can be safely removed, and will spend some time investigating the code to see. I did a quick run anyway to get an idea of the theoretical gains we could get from optimizing this, and got about 5 TFLOPS out of it.

-------------------------- DeepSpeed Flops Profiler --------------------------
Summary of forward pass:
Profile step:                   10
Number of parameters:           394
Number of multiply-accumulate operations (MACs):   8267.5 G
Number of floating point operations ( = 2 * MACs):   16534.99 G
Latency:                        535.9 ms
Floating point operations per second(FLOPS):   30.85 TFLOPS

With synchronization in fetch: [profiler screenshot]

Without synchronization in fetch: [profiler screenshot]

jfc4050 commented 3 years ago

Found an unexpected optimization. I was looking at the last profile and was a bit confused by the big (300-1000+ us) gap between synchronize_communication and the end of fetch_sub_module, especially since there's very little code between that call and the end of the function. (Code ref: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage3.py#L440-L446)

I was also kind of confused by the CUDA API calls (reduce, memcpyAsync) happening at the same time. It turns out we make a call to .norm() to generate the debug string passed to print_rank_0. I commented it out and got another TFLOP or two, and the GPU also got a bit busier after that.

-------------------------- DeepSpeed Flops Profiler --------------------------
Summary of forward pass:
Profile step:                   10
Number of parameters:           394
Number of multiply-accumulate operations (MACs):   8267.5 G
Number of floating point operations ( = 2 * MACs):   16534.99 G
Latency:                        489.44 ms
Floating point operations per second(FLOPS):   33.78 TFLOPS

More generally, the DeepSpeed codebase is covered pretty densely with these print_rank_0 and see_memory_usage calls, where each is passed some sort of debug message. There's usually a statement at the beginning of the function to return immediately if we don't want to print anything, but by then we've already lost: we've already allocated memory for and constructed the string, and in this case also performed expensive tensor operations. I wonder if there are more optimization opportunities like this elsewhere.
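A minimal illustration of the pitfall, with a stand-in helper rather than the real DeepSpeed print_rank_0:

    import torch

    VERBOSE = False

    def print_rank_0(message):
        # stand-in: the real helper also checks the distributed rank
        if VERBOSE:
            print(message)

    param = torch.randn(1_000_000)

    # the f-string (and the .norm() reduction) run even though nothing is printed,
    # because Python evaluates the argument before the call
    print_rank_0(f"param norm = {param.norm()}")

    # cheaper: guard the expensive work at the call site, or defer building the string
    if VERBOSE:
        print_rank_0(f"param norm = {param.norm()}")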

[profiler screenshot]

jfc4050 commented 3 years ago

Just noticed that we aren't prefetching anything but param 0; not sure if this is specific to our use case.

If I enable this print statement I get a bunch of them: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage3.py#L212-L215

I get the output below after changing this bit of code https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage3.py#L376-L378 to only print if prefetch_numel > 0 (and adding some more info to the debug message):

[2021-05-25 16:27:32,188] [INFO] [stage3.py:39:print_rank_0] ----PreFetch: {'prefetch_numel': 141546240, 'next_submodule_numel': 0, 'params_prefetched': [0], 'total_available_numel': 141546240, 'max_limit': 1000000000}
[2021-05-25 16:27:34,715] [INFO] [stage3.py:2704:_overflow_clean_up] [deepscale] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 2147483648.0, reducing to 1073741824.0
[2021-05-25 16:27:34,716] [INFO] [logging.py:60:log_dist] [Rank 0] rank=0 time (ms) | forward_microstep: 953.39 | backward_microstep: 1573.24 | backward_inner_microstep: 1562.73 | backward_allreduce_microstep: 10.45 | step_microstep: 2.66
[2021-05-25 16:27:34,716] [INFO] [logging.py:60:log_dist] [Rank 0] rank=0 time (ms) | forward: 953.36 | backward: 1573.22 | backward_inner: 1562.69 | backward_allreduce: 10.43 | step: 2.64

Taking a look through PrefetchCoordinator to see if I can find a root cause.

stephenrawls commented 3 years ago

@tjruwase, any thoughts on this? Is anyone on the DeepSpeed team able to help us understand why Stage 3 parameter pre-fetching does not seem to be working for our Hugging Face model?

tjruwase commented 3 years ago

@stephenrawls, apologies for the delay on our side in investigating this issue. I will take a look today.

jfc4050 commented 3 years ago

From what I can tell

tjruwase commented 3 years ago

@stephenrawls, can you please share your command-line?

stephenrawls commented 3 years ago

Yes, @jfc4050 provided it earlier; here is the reproduction script: https://gist.github.com/jfc4050/ab1ecddc9c13290ecc1dde4a4eee0358

tjruwase commented 3 years ago

@stephenrawls, thanks for your response. Yes, I am able to run @jfc4050's repro and see some regression between ZeRO-2 and ZeRO-3. But I am also curious about your RobertaForMaskedLM, since that is a real model. Thanks.

stephenrawls commented 3 years ago

That repro code also uses the RobertaForMaskedLM model class.

If you are just asking which model dimensions I benchmarked with, I was using the following, which is a 10B parameter model with dimensions roughly inspired by GPT-3 13B:

    max_position_embeddings: 514
    type_vocab_size: 1
    num_attention_heads: 40
    num_hidden_layers: 30
    hidden_size: 5120
    intermediate_size: 20480
    gradient_checkpointing: true

As for my command line, I am using a custom training script that integrates Ignite, DeepSpeed, and Hugging Face. The repro script that @jfc4050 shared is very similar, but condensed a bit and with many things hard-coded to make it easier for you to run.

tjruwase commented 3 years ago

@stephenrawls, thanks, that is all I needed. The dimensions are important because ZeRO-3 is targeted at very large models, so some regression on smaller models that ZeRO-2 can handle is acceptable. I will continue the investigation.

stephenrawls commented 3 years ago

(For reference, I would like to eventually benchmark performance on models larger than 10B parameters, but I am getting CPU out-of-memory errors when using Stage 3 on models bigger than 20B parameters, even with NVMe offloading turned on. I will file a separate GitHub issue for that OOM problem.)

tjruwase commented 3 years ago

@stephenrawls or @jfc4050, can you please share this snippet of your log?

-------------------------- DeepSpeed Flops Profiler --------------------------
Summary of forward pass:
Profile step:                   10
Number of parameters:           394
Number of multiply-accumulate operations (MACs):   8267.5 G
Number of floating point operations ( = 2 * MACs):   16534.99 G
Latency:                        593.05 ms
Floating point operations per second(FLOPS):   27.88 TFLOPS

tjruwase commented 3 years ago

Never mind, I see these snippets above. I noticed the profiler was not counting the number of parameters correctly for ZeRO-3 (394 instead of 2020M).

tjruwase commented 3 years ago

@stephenrawls and @jfc4050, is it possible for you to try out PR #1170?

jfc4050 commented 3 years ago

Hey @tjruwase, thanks for the PR! I wasn't able to look at this for a couple of weeks, and something seems to have changed with my setup between then and now. I'm getting better performance over baseline with your PR, but both are much worse than what I was seeing when I last measured. I will see if I can figure out why tomorrow.

AFAICT the PR fixes prefetching by fixing the hook execution order, and removes the unnecessary .norm() call. Do you have any thoughts about removing the two synchronization points I mentioned above? I was able to get significantly better performance by changing those on top of the two fixes you implemented in your PR, but it's a bit more difficult to verify that they aren't affecting correctness.

tjruwase commented 3 years ago

@jfc4050, thanks for your response and help.

Yes, you have read the PR correctly. I have not removed the synchronization points, as they require a bit more analysis/testing. Please share what you find out about reproducing the original perf that you saw. Anyway, I consider the PR a work in progress at this point. I hope that with help from you and @stephenrawls, we can get a better understanding of the perf blockers in ZeRO-3 relative to ZeRO-2.

By the way, I added some timers for the ZeRO-3 fetching and prefetching, which you can enable by setting wall_clock_breakdown to true. Unfortunately, the timers themselves introduce sync points and so affect timing somewhat.

jfc4050 commented 3 years ago

Oh, I just realized I forgot to post this before I left for two weeks. The actual all_gather calls are surprisingly slow (it's the non-blocking all_gather); I'm not sure how much we can do about this since it's a direct call to a PyTorch function. This really crops up when there are several all-gather calls in sequence for module fetches. Since it's done synchronously with submodule.forward, it delays sending work to the GPU and results in suboptimal utilization.

[profiler screenshot: synchronous fetch]

zarzen commented 3 years ago

I have similar findings. An all_gather for one parameter takes about 0.5 ms to finish, and most of the all_gather operations are very small, e.g., gathering 320 elements per partition.

[profiler screenshot]

The pre-launch typically takes about 130 us, the launch takes about 250 us, and there is a torch.cuda.synchronize() sync point when pre-launching the next all_gather operation.

One way I can think of to reduce the number of all_gather ops is batching.
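Something along these lines; a sketch of the batching idea (assumes torch.distributed is already initialized; not the DeepSpeed implementation):

    import torch
    import torch.distributed as dist

    def batched_all_gather(partitions, world_size):
        """Gather many small per-parameter partitions with one all_gather
        by flattening them into a single contiguous buffer first."""
        flat_in = torch.cat([p.contiguous().view(-1) for p in partitions])
        flat_out = torch.empty(world_size * flat_in.numel(),
                               dtype=flat_in.dtype, device=flat_in.device)
        out_shards = list(flat_out.chunk(world_size))  # one contiguous slice per rank
        dist.all_gather(out_shards, flat_in)           # single collective instead of one per param
        return flat_out  # caller slices each parameter's data back out of every rank's shard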

tjruwase commented 3 years ago

@jfc4050 and @zarzen, thanks for sharing your insights; these are very valuable. I will resume this line of investigation next week, as I am off this week.

zarzen commented 3 years ago

https://github.com/microsoft/DeepSpeed/blob/b3870363e026717752c119b21870e2886b9ad92d/deepspeed/runtime/zero/stage3.py#L435 Changing this line to the following is going to improve forward performance, as it launches a single all-gather op for the whole list of params:

 self._all_gather(partitioned_params, async_op=False) 

jfc4050 commented 3 years ago

@zarzen sorry, I don't completely follow. How would changing it to a blocking call result in the all-gather being batched?

zarzen commented 3 years ago

@jfc4050 By checking the code path, async_op=False ends up invoking the _allgather_params function here: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/partition_parameters.py#L847, in which all_gather is invoked only once with one large tensor.

(Here is the code that calls _allgather_params: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/partition_parameters.py#L643_L646)

BTW, I just submitted a draft PR to improve the allgather_params op: #1188

jfc4050 commented 3 years ago

Oh thanks, I totally missed it; interesting that it would do that. Let me try pulling in your change.

zarzen commented 3 years ago

Another thing to notice: async_op=True is not going to give you truly async all-gathers, because of the torch.cuda.synchronize() here.

jfc4050 commented 3 years ago

Thanks @zarzen for the insight and the idea to batch the all-gather calls; it's really promising. I think it would still be good to have an async batched all-gather. I'm working on some ideas for optimizing prefetching; still working on testing/analysis and fixing some issues, but so far performance seems pretty good (~37 TFLOPS/GPU compared to the ~16 TFLOPS/GPU I started with).

https://github.com/jfc4050/DeepSpeed/commit/319e2af4d4516c38ed7513d3f5c9456a0f1fb046 EDIT: broke this link by rebasing, use this instead https://github.com/jfc4050/DeepSpeed/tree/stage3

zarzen commented 3 years ago

@jfc4050 Interesting work on the prefetching. For the async all_gather, I think simply using handle.wait() could give you wrong results from the all-gather.

jfc4050 commented 3 years ago

Do you mean for the batched all-gather, or just in general?

zarzen commented 3 years ago

In general.

handle.wait() only blocks for completion of the job when the environment variable NCCL_BLOCKING_WAIT=1 is set.

The underlying function of handle.wait() calls this function: https://github.com/pytorch/pytorch/blob/44922f26f5b1d7b6ee8d217a9d32bc0e40ec47a6/torch/lib/c10d/ProcessGroupNCCL.cpp#L316_L320

It checks the variable blockingWait_, which is False by default, so it basically skips all of the following condition checks/synchronizations. To enable this branch, you have to set NCCL_BLOCKING_WAIT=1.
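In other words, roughly the following sketch (as I understand it, the env var must be set before the NCCL process group is created, and dist.init_process_group has already been called elsewhere):

    import os
    import torch
    import torch.distributed as dist

    # without this, handle.wait() on an NCCL collective only enqueues a stream
    # dependency and returns almost immediately
    os.environ["NCCL_BLOCKING_WAIT"] = "1"

    def blocking_all_gather(input_tensor: torch.Tensor, world_size: int):
        """Async all_gather whose wait() actually blocks the host (NCCL_BLOCKING_WAIT=1)."""
        outputs = [torch.empty_like(input_tensor) for _ in range(world_size)]
        handle = dist.all_gather(outputs, input_tensor, async_op=True)
        handle.wait()  # blocks until the collective has finished on the GPU
        return outputs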

jfc4050 commented 3 years ago

I see. My thought was that for correctness we don't need the actual all-gather to complete before handle.wait() returns, as long as the all-gather result is consumed in the same CUDA stream. Does that seem reasonable?

zarzen commented 3 years ago

Yes, that's true. But the stream used for collective communication is out of your control: using with torch.cuda.stream(some_stream) is not going to work as expected, because the underlying implementation of all-gather has its own CUDA stream.

jfc4050 commented 3 years ago

BTW, this conversation made me notice that I forgot to take the all-gather out of the separate stream (it was in a with torch.cuda.stream(self.comm_stream) block). I went ahead and did that, and performance dropped a little, to ~38 TFLOPS/GPU from ~40 TFLOPS/GPU (I accidentally underreported earlier).

Here's a profiler shot after doing that (this is from the first forward pass, where there's no prefetching):

[profiler screenshot]

So it's consistent with what you said about all-gather using its own stream, but it also appears that there's some synchronization going on regardless: the kernel operations in the default stream are not starting until the all-gathers have completed. This is consistent with the PyTorch documentation (https://pytorch.org/docs/stable/distributed.html) and seems like it would be safe:

> wait() - in the case of CPU collectives, will block the process until the operation is completed. In the case of CUDA collectives, will block until the operation has been successfully enqueued onto a CUDA stream and the output can be utilized on the default stream without further synchronization.

Here's a later forward pass from the same run, where prefetching is happening:

[profiler screenshot]

zarzen commented 3 years ago

@jfc4050 Thanks for sharing the profiling results. Are they profiled on V100s?

I have tried the approach described in the documentation, but I am still getting wrong results. When profiling the following function, the last part, the with torch.cuda.stream(...) block, cannot block the function.

import torch
import torch.distributed as dist

def _torch_allgather_once(output_tensors,
                          input_tensors,
                          partition_sizes,
                          rank,
                          world_size):
    """Launch one async all_gather per partition, then wait on the last handle."""
    s = torch.cuda.Stream()
    handles = []
    for part_idx, part_size in enumerate(partition_sizes):
        output_t = output_tensors[part_idx]
        input_t = input_tensors[part_idx]

        output_list = []
        for i in range(world_size):
            out_tensor = output_t.narrow(0, i * part_size, part_size)
            output_list.append(out_tensor)

        h = dist.all_gather(output_list, input_t, async_op=True)
        handles.append(h)

    handles[-1].wait()
    # torch.cuda.synchronize()
    with torch.cuda.stream(s):
        s.wait_stream(torch.cuda.default_stream())
        # output_list[-1].add_(1)
        return

jfc4050 commented 3 years ago

Yep, I'm running on a single host with 8 V100s.

I think it's expected that the s.wait_stream(torch.cuda.default_stream()) doesn't block the function; it's just saying that no more work on stream s can happen until the work on the default stream completes. But even if you did a torch.cuda.default_stream().synchronize(), the function still wouldn't be blocked on the all-gather, since it independently decides to use a different stream.

Looking more closely at the profiler, there are cudaStreamWaitEvent calls that are used to block further work on the default stream until the all-gathers have completed.

They come from the ProcessGroupNCCL::WorkNCCL::wait -> ProcessGroupNCCL::WorkNCCL::synchronizeInternal call chain linked earlier, which leads to these calls: https://github.com/pytorch/pytorch/blob/44922f26f5b1d7b6ee8d217a9d32bc0e40ec47a6/torch/lib/c10d/ProcessGroupNCCL.cpp#L306-L312 https://github.com/pytorch/pytorch/blob/10b929bbfb288b5214f5a8998043b9dbf44cd2f4/aten/src/ATen/cuda/CUDAEvent.h#L120-L127. So I believe this means that once .wait() is called on an async all-gather (or presumably any other NCCL-backed collective), no more work can happen on the default stream until the all-gather has completed.
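In PyTorch terms, the mechanism is roughly the following sketch (a stand-in copy instead of the real all-gather kernel; not the C++ internals):

    import torch

    nccl_stream = torch.cuda.Stream()   # stand-in for the internal NCCL stream
    x = torch.randn(1 << 20, device="cuda")
    out = torch.empty_like(x)

    nccl_stream.wait_stream(torch.cuda.current_stream())  # order after producing `x`
    with torch.cuda.stream(nccl_stream):
        out.copy_(x)                    # stand-in for the all-gather kernel
        done = torch.cuda.Event()
        done.record()                   # event recorded on the NCCL stream

    # what WorkNCCL::wait()/synchronizeInternal effectively does: block *future* work
    # submitted to the current stream (not the host) until `done` has completed
    torch.cuda.current_stream().wait_event(done)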

zarzen commented 3 years ago

> So I believe this means that once .wait() is called on an async all-gather (or presumably any other NCCL-backed collective), no more work can happen on the default stream until the all-gather has completed.

Do you mean the wait() simply blocks CUDA kernels from launching on current_stream? (As here it uses https://github.com/pytorch/pytorch/blob/44922f26f5b1d7b6ee8d217a9d32bc0e40ec47a6/torch/lib/c10d/ProcessGroupNCCL.cpp#L308)

Anyway, this design is unintuitive to me ... it is pretty hard to get the status of a collective-communication operation without explicit synchronization (torch.cuda.synchronize()).

zarzen commented 3 years ago

I mean you can mistakenly launch other operations on a custom CUDA stream.

jfc4050 commented 3 years ago

> So I believe this means that once .wait() is called on an async all-gather (or presumably any other NCCL-backed collective), no more work can happen on the default stream until the all-gather has completed.

> Do you mean the wait() simply blocks CUDA kernels from launching on current_stream? (As here it uses https://github.com/pytorch/pytorch/blob/44922f26f5b1d7b6ee8d217a9d32bc0e40ec47a6/torch/lib/c10d/ProcessGroupNCCL.cpp#L308)

Yes, sorry, you're right, it's current_stream; in my specific case current_stream was the default stream.

> Anyway, this design is unintuitive to me ... it is pretty hard to get the status of a collective-communication operation without explicit synchronization (torch.cuda.synchronize()).

Yeah, agreed. AFAICT if I have a handle from an NCCL collective, I don't have a way to tell if the work is done, and if I perform any work on the results anywhere but current_stream without explicit synchronization (even after calling .wait()), I get race conditions. However, if we do frequent cross-stream synchronizations, it unnecessarily prevents efficient overlapping of computation and communication.

Side note: from going through the torch.distributed docs a little more, it turns out you actually don't get any guarantees with async_op=False either, so really you have the exact same risks either way unless you either explicitly synchronize all the time or take care when switching streams:

> Synchronous operation - the default mode, when async_op is set to False. When the function returns, it is guaranteed that the collective operation is performed. In the case of CUDA operations, it is not guaranteed that the CUDA operation is completed, since CUDA operations are asynchronous. For CPU collectives, any further function calls utilizing the output of the collective call will behave as expected. For CUDA collectives, function calls utilizing the output on the same CUDA stream will behave as expected. Users must take care of synchronization under the scenario of running under different streams. For details on CUDA semantics such as stream synchronization, see CUDA Semantics. See the below script to see examples of differences in these semantics for CPU and CUDA operations.

jfc4050 commented 3 years ago

> AFAICT if I have a handle from an NCCL collective, I don't have a way to tell if the work is done

I take this back; it looks like there is a way to get a CUDA event out of a handle, so it would then be possible to query for completion.

https://pytorch.org/docs/stable/generated/torch.cuda.Event.html#torch.cuda.Event
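For example, one way to poll for completion without blocking the host is to record your own event on the current stream after wait() has enqueued the dependency; a sketch (assumes the process group is already initialized, and this is not an internal handle API):

    import torch
    import torch.distributed as dist

    def all_gather_with_event(input_tensor: torch.Tensor, world_size: int):
        """Launch an async all_gather and return (outputs, event); event.query()
        becomes True once the gather (and prior work on this stream) has finished."""
        outputs = [torch.empty_like(input_tensor) for _ in range(world_size)]
        handle = dist.all_gather(outputs, input_tensor, async_op=True)
        handle.wait()       # NCCL backend: enqueues the dependency on the current stream
        done = torch.cuda.Event()
        done.record()       # recorded after the dependency, so it completes after the gather
        return outputs, done

    # later: if done.query() is True, the gathered outputs are safe to read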

zarzen commented 3 years ago

I guess the ipc-handle is different from the handle returned from the torch.distributed APIs, because the returned handle is a wrapper for the class ProcessGroup::Work.

jfc4050 commented 3 years ago

Quick note: c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::allgather_coalesced is stubbed out in PyTorch. @zarzen and I both have DeepSpeed/Python implementations of this, but it could benefit everyone to move the implementation to PyTorch/C++ at some point.

https://github.com/pytorch/pytorch/blob/44922f26f5b1d7b6ee8d217a9d32bc0e40ec47a6/torch/lib/c10d/ProcessGroupNCCL.cpp#L1312-L1318

zarzen commented 3 years ago

> Thanks @zarzen for the insight and the idea to batch the all-gather calls; it's really promising. I think it would still be good to have an async batched all-gather. I'm working on some ideas for optimizing prefetching; still working on testing/analysis and fixing some issues, but so far performance seems pretty good (~37 TFLOPS/GPU compared to the ~16 TFLOPS/GPU I started with)
>
> jfc4050@319e2af EDIT: broke this link by rebasing, use this instead https://github.com/jfc4050/DeepSpeed/tree/stage3
>
> - remove torch.cuda.synchronize() calls, so each module fetch can be blocked on individually
> - added all_gather_coalesced, which is similar to the batched all-gather you implemented but can be async. I'm missing some of the optimizations you proposed related to avoiding memcpys. The parameters in each module are fetched in a single all_gather_coalesced call
> - fixed parameter persistence issues: persistent parameters were actually getting partitioned anyway; there was a release call that was ignoring persistence
> - using a queue for prefetching, which can be computed once instead of determining which modules to prefetch at each step
> - change the fetching order from (fetch current, block, prefetch) to (fetch current, prefetch, block) so prefetches can progress while we are waiting for the current module
> - minor changes to module re-use logic

@jfc4050 I would like to try the prefetching you have modified, but I see a runtime error when running the bing_bert model:

    hidden_states = layer_module(hidden_states, attention_mask)
  File "/usr/local/lib64/python3.7/site-packages/torch/nn/modules/module.py", line 748, in _call_impl
    hook_result = hook(self, input, result)
  File "/home/ec2-user/DeepSpeed/deepspeed/runtime/zero/stage3.py", line 1332, in _post_forward_module_hook
    self.post_sub_module_forward_function(module)
  File "/home/ec2-user/DeepSpeed/deepspeed/utils/nvtx.py", line 9, in wrapped_fn
    return func(*args, **kwargs)
  File "/home/ec2-user/DeepSpeed/deepspeed/runtime/zero/stage3.py", line 1414, in post_sub_module_forward_function
    self.param_coordinator.release_sub_module(sub_module)
  File "/home/ec2-user/DeepSpeed/deepspeed/utils/nvtx.py", line 9, in wrapped_fn
    return func(*args, **kwargs)
  File "/home/ec2-user/DeepSpeed/deepspeed/runtime/zero/stage3.py", line 330, in release_sub_module
    param.ds_active_sub_modules.remove(submodule.id)
KeyError: 13

Any suggestion for a quick fix?

jfc4050 commented 3 years ago

> Any suggestion for a quick fix?

For now try changing it to a .discard(...); there may be some issue that didn't appear in the RoBERTa model I'm using for testing.

zarzen commented 3 years ago

@jfc4050 Thanks for the suggestion! Another small issue I found is that we need a param.ds_status check in the wait function of the AllGatherCoalescedHandle class, because some params may already have ds_status AVAILABLE.

Before this line (https://github.com/jfc4050/DeepSpeed/blob/stage3/deepspeed/runtime/zero/partition_parameters.py#L345):

    if param.ds_status != ZeroParamStatus.INFLIGHT:
        continue

But the code does not work with checkpoint_activations = True in the bing_bert model (modify this line: https://github.com/zarzen/DeepSpeedExamples/blob/edb715997239270201f1c0d12936ee610a53a818/bing_bert/nvidia/modelingpreln.py#L1129).

With that setup, I get the following error:

  0%|          | 9/152966 [00:17<81:39:41,  1.92s/it]
  0%|          | 9/152942 [00:17<81:53:24,  1.93s/it]

Traceback (most recent call last):
  File "/home/ec2-user/DeepSpeedExamples/bing_bert/zero_opt_experiments/scripts/../../deepspeed_train.py", line 607, in <module>
    main()
  File "/home/ec2-user/DeepSpeedExamples/bing_bert/zero_opt_experiments/scripts/../../deepspeed_train.py", line 600, in main
    run(args, model, optimizer, start_epoch)
  File "/home/ec2-user/DeepSpeedExamples/bing_bert/zero_opt_experiments/scripts/../../deepspeed_train.py", line 566, in run
    train(args, index, model, optimizer, pretrain_dataset_provider)
  File "/home/ec2-user/DeepSpeedExamples/bing_bert/zero_opt_experiments/scripts/../../deepspeed_train.py", line 179, in train
    model.network.backward(loss)
  File "/home/ec2-user/DeepSpeed/deepspeed/utils/nvtx.py", line 9, in wrapped_fn
    return func(*args, **kwargs)
  File "/home/ec2-user/DeepSpeed/deepspeed/runtime/engine.py", line 1191, in backward
    self.optimizer.backward(loss)
  File "/home/ec2-user/DeepSpeed/deepspeed/utils/nvtx.py", line 9, in wrapped_fn
    return func(*args, **kwargs)
  File "/home/ec2-user/DeepSpeed/deepspeed/runtime/zero/stage3.py", line 2811, in backward
    self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
  File "/home/ec2-user/DeepSpeed/deepspeed/runtime/fp16/loss_scaler.py", line 53, in backward
    scaled_loss.backward(retain_graph=retain_graph)
  File "/usr/local/lib64/python3.7/site-packages/torch/tensor.py", line 233, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/usr/local/lib64/python3.7/site-packages/torch/autograd/__init__.py", line 147, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: The size of tensor a (0) must match the size of tensor b (2560) at non-singleton dimension 1

jfc4050 commented 3 years ago

Thanks @zarzen, I pushed a fix to the first issue you found.

Regarding the second: thinking about it more, I think I shouldn't have changed it to batch the parameters by module. That would be suboptimal most of the time, because we can't control the number of bytes per all-gather to saturate communication bandwidth. I'm going to change it back to not batch by module, so the user can tune prefetch_bucket_size depending on whether they are using PCIe/NVLink/etc. (roughly as sketched below).

I suspect that the second issue will go away when I do that, but I'm not entirely sure why it only happens when we are checkpointing activations.
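Roughly what I mean by bucketing by size rather than by module (a sketch with a hypothetical helper; the actual change lives in the branch linked above):

    from typing import Iterable, Iterator, List
    import torch

    def bucket_by_numel(params: Iterable[torch.Tensor],
                        bucket_numel: int = int(5e8)) -> Iterator[List[torch.Tensor]]:
        """Group parameters into buckets of roughly `bucket_numel` elements each,
        so every bucket can be fetched with one coalesced all-gather."""
        bucket, count = [], 0
        for p in params:
            bucket.append(p)
            count += p.numel()
            if count >= bucket_numel:
                yield bucket
                bucket, count = [], 0
        if bucket:
            yield bucket

    # each yielded bucket would be fetched with a single coalesced all-gather;
    # bucket_numel can then be tuned for the interconnect (PCIe vs NVLink, etc.)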

zarzen commented 3 years ago

@tjruwase Why do you think the batching is going to affect the checkpointing?

I think the original ZeRO Stage 3, the version before your changes, also does module-level batching, since the prefetch function is invoked when the forward hook is triggered.