vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Misc]: Memory Order in Custom Allreduce #8404

Closed: HydraQYH closed this issue 1 month ago

HydraQYH commented 1 month ago

Memory Order in Custom Allreduce

In custom allreduce, I noticed that the Signal* pointers have a volatile qualifier, and that there is no memory fence in the start_sync function. Can volatile alone guarantee the right memory order? The program order in start_sync is:

  1. set start flag to other GPU's Signal
  2. read start flag from local GPU's Signal
  3. allreduce(pull data from other GPU)

In my opinion, without a memory fence, step 3 may become visible before step 2 or step 1.
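The ordering concern in the three steps above can be sketched on the host. This is only an analogy, not the vLLM kernel: two CPU threads stand in for two GPUs, `std::atomic` with release/acquire ordering stands in for whatever fence the CUDA code would need, and the names `Peer`, `rank_step`, and `run_two_ranks` are hypothetical. Each rank writes its input, sets the start flag on its peer (step 1), spins on its own flag (step 2), and only then reads the peer's data (step 3).

```cpp
#include <atomic>
#include <cassert>
#include <thread>
#include <utility>

// Host-side stand-in for one GPU's state (hypothetical layout).
struct Peer {
    int data = 0;               // plain memory, like the allreduce input buffer
    std::atomic<int> start{0};  // like the start flag inside Signal
};

// One rank's view of the three steps from the issue.
int rank_step(Peer& self, Peer& peer, int my_value) {
    self.data = my_value;  // the input is written before signalling
    // Step 1: set the start flag on the peer. Release ordering makes the
    // write to self.data visible to whoever acquires this flag.
    peer.start.store(1, std::memory_order_release);
    // Step 2: wait for the peer to set our flag. Acquire ordering keeps
    // the read in step 3 from being hoisted above this loop.
    while (self.start.load(std::memory_order_acquire) != 1) {
        std::this_thread::yield();
    }
    // Step 3: pull the peer's data and reduce.
    return self.data + peer.data;
}

// Runs both ranks concurrently and returns what each one computed.
std::pair<int, int> run_two_ranks(int a, int b) {
    Peer p0, p1;
    int r0 = 0, r1 = 0;
    std::thread t0([&] { r0 = rank_step(p0, p1, a); });
    std::thread t1([&] { r1 = rank_step(p1, p0, b); });
    t0.join();
    t1.join();
    return {r0, r1};
}
```

Calling `run_two_ranks(3, 4)` yields 7 on both ranks. With relaxed ordering on the flag instead, the C++ memory model would no longer guarantee that step 3 sees the peer's write, which is the host-side counterpart of the question raised here.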

youkaichao commented 1 month ago

this is a great observation!

@kanghui0204 also found this problem. It seems that adding __threadfence_system can solve the problem, but at a significant performance cost.

@HydraQYH do you have any ideas on how to solve it?

also cc @hanzhi713 if you still have bandwidth to investigate.

hanzhi713 commented 1 month ago

I don't think this is the issue. Step 3 will be executed after step 1 and 2 due to __syncthreads(), which is also a memory fence. During the execution of step 3, all GPUs should've at least entered the custom allreduce kernel (otherwise we'll be stuck at step 2), which means data is ready.

Even if step 3 got the wrong data, it won't cause a hang. If it hangs it must occur in one of the while loops.

youkaichao commented 1 month ago

@hanzhi713 Adding __threadfence_system() at https://github.com/vllm-project/vllm/blob/0af3abe3d3225449c907d75eb3d2ae4b83bd21a1/csrc/custom_all_reduce.cuh#L137 seems to work. This solution was found by @kanghui0204.

I don't know whether we can use a weaker sync op here; __threadfence_system might be too conservative.

hanzhi713 commented 1 month ago

@youkaichao Can you try what I proposed in the second bullet point of #8410?

I think the rationale behind this (I'm still thinking about it too) is that the end reset at https://github.com/vllm-project/vllm/blob/0af3abe3d3225449c907d75eb3d2ae4b83bd21a1/csrc/custom_all_reduce.cuh#L135 got reordered after https://github.com/vllm-project/vllm/blob/0af3abe3d3225449c907d75eb3d2ae4b83bd21a1/csrc/custom_all_reduce.cuh#L162, causing an indefinite wait.

If this is indeed the case, it should be fixed by changing https://github.com/vllm-project/vllm/blob/0af3abe3d3225449c907d75eb3d2ae4b83bd21a1/csrc/custom_all_reduce.cuh#L156 to an unconditional fence.

HydraQYH commented 1 month ago

@hanzhi713 Thanks for the reply. I also thought about __syncthreads(). I'm not sure whether __syncthreads() has memory fence semantics. The CUDA programming guide only says: "__syncthreads() waits until all threads in the thread block have reached this point and all global and shared memory accesses made by these threads prior to __syncthreads() are visible to all threads in the block." That is why I opened this issue.

youkaichao commented 1 month ago

@hanzhi713 which will be more efficient?

adding __threadfence_system here: https://github.com/vllm-project/vllm/blob/0af3abe3d3225449c907d75eb3d2ae4b83bd21a1/csrc/custom_all_reduce.cuh#L137

or

unconditionally use __threadfence_system here:

https://github.com/vllm-project/vllm/blob/0af3abe3d3225449c907d75eb3d2ae4b83bd21a1/csrc/custom_all_reduce.cuh#L156

hanzhi713 commented 1 month ago

@HydraQYH "... are visible to all threads in the block": this is an even stronger guarantee than a memory fence. A memory fence only guarantees ordering. This also guarantees visibility.

HydraQYH commented 1 month ago

@youkaichao I have tried adding __threadfence_system() at https://github.com/vllm-project/vllm/blob/0af3abe3d3225449c907d75eb3d2ae4b83bd21a1/csrc/custom_all_reduce.cuh#L137. On my A100, it adds about 6 us of latency.

I also tried changing the code to use a weaker memory fence, as TensorRT-LLM does. That adds about 1~3 us of latency. It is better than __threadfence_system(), but still not as good as @hanzhi713's original implementation without a memory fence.

I can submit my plan for code review.

hanzhi713 commented 1 month ago

@youkaichao The second option, i.e. unconditionally using __threadfence_system at https://github.com/vllm-project/vllm/blob/0af3abe3d3225449c907d75eb3d2ae4b83bd21a1/csrc/custom_all_reduce.cuh#L156. It will add some latency to one-stage allreduce, but two-stage allreduce already has it, so the overall impact is smaller.

HydraQYH commented 1 month ago

Regarding the weaker fence I mentioned: TensorRT-LLM uses both a fence (acquire-release) and __syncthreads: https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu#L172

@hanzhi713 So maybe using both of them is more robust?
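For readers unfamiliar with the acquire-release fence pattern mentioned here, a minimal host-side sketch follows. It is not TensorRT-LLM's code and says nothing about CUDA fence scopes; it only shows, in standard C++, how standalone fences placed around relaxed flag accesses order a plain data write before a flag store and a flag load before a plain data read. The names `FencedChannel`, `publish`, `receive`, and `run_fenced_handshake` are hypothetical.

```cpp
#include <atomic>
#include <cassert>
#include <thread>

// Hypothetical channel: plain payload plus a flag accessed with relaxed ops.
struct FencedChannel {
    int payload = 0;
    std::atomic<int> flag{0};
};

// Writer: the release fence orders the payload write before the flag store.
void publish(FencedChannel& ch, int value) {
    ch.payload = value;
    std::atomic_thread_fence(std::memory_order_release);
    ch.flag.store(1, std::memory_order_relaxed);
}

// Reader: the acquire fence orders the flag load before the payload read.
int receive(FencedChannel& ch) {
    while (ch.flag.load(std::memory_order_relaxed) != 1) {
        std::this_thread::yield();
    }
    std::atomic_thread_fence(std::memory_order_acquire);
    return ch.payload;
}

// Runs one writer and one reader thread and returns what the reader saw.
int run_fenced_handshake(int value) {
    FencedChannel ch;
    int seen = -1;
    std::thread reader([&] { seen = receive(ch); });
    std::thread writer([&] { publish(ch, value); });
    writer.join();
    reader.join();
    return seen;
}
```

`run_fenced_handshake(42)` returns 42. The fences carry the ordering here, so the flag accesses themselves can stay relaxed; whether the analogous trade-off is cheaper than __threadfence_system on a GPU is exactly what the latency numbers in this thread are about.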

kanghui0204 commented 1 month ago

I saw @youkaichao's comment. I think the problem is:

(image: problem)

I don't think switching end and start between 0/1 is a good approach. I think the solution below should be better, and it doesn't need a fence in one-shot allreduce. What do you think? @hanzhi713 @HydraQYH

(image: solution)

hanzhi713 commented 1 month ago

@kanghui0204 Your solution seems reasonable. It's worth a shot to see the performance. Using increments removes the need to reset the flags, and with it the race condition. I like the idea.
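The idea of replacing resettable 0/1 flags with counters that only ever increase can be sketched for two ranks as below. The original diagram is not available here, so this is an assumed reading of the proposal rather than the actual design: host threads stand in for GPUs, and `Slot`, `counter_barrier`, and `run_rounds` are hypothetical names. Each rank keeps a private round number, writes it into the peer's slot, and waits until its own slot has reached the same round. Nothing is ever reset, so there is no reset that could be reordered.

```cpp
#include <atomic>
#include <cassert>
#include <thread>

// One slot per rank; it is written only by the peer and never reset.
struct Slot {
    std::atomic<unsigned> value{0};
};

// Signal the peer with the next round number, then wait for the peer.
void counter_barrier(Slot& peer_slot, Slot& my_slot, unsigned& round) {
    ++round;
    peer_slot.value.store(round, std::memory_order_release);
    // "<" rather than "!=": the peer may already be one barrier ahead.
    while (my_slot.value.load(std::memory_order_acquire) < round) {
        std::this_thread::yield();
    }
}

// Runs `rounds` iterations of write / start barrier / read / end barrier
// on two ranks and returns how many stale reads were observed.
int run_rounds(int rounds) {
    Slot slots[2];
    int data[2] = {0, 0};
    int mismatches[2] = {0, 0};
    auto rank = [&](int me) {
        const int peer = 1 - me;
        unsigned round = 0;
        for (int k = 1; k <= rounds; ++k) {
            data[me] = k;                                      // produce input
            counter_barrier(slots[peer], slots[me], round);    // start sync
            if (data[peer] != k) ++mismatches[me];             // consume peer data
            counter_barrier(slots[peer], slots[me], round);    // end sync
        }
    };
    std::thread t0(rank, 0);
    std::thread t1(rank, 1);
    t0.join();
    t1.join();
    return mismatches[0] + mismatches[1];
}
```

`run_rounds(200)` returns 0 in this simulation. The same slot serves both the start and the end barrier because the round number keeps growing across calls.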

kanghui0204 commented 1 month ago

OK, I'll try it sometime later

HydraQYH commented 1 month ago

@kanghui0204 Very interesting! I guess it works really well in the two-GPU scenario. How about 4 or 8 GPUs? I'm excited to see the performance results.

kanghui0204 commented 1 month ago

@HydraQYH I think this works for any number of GPUs, because you can prepare a pair of flags for each peer GPU.

youkaichao commented 1 month ago

@kanghui0204 I think you only need one local flag regardless of the number of GPUs, while the number of global flags grows with the number of GPUs?

Every GPU has a flag array int flags[N], where flags[i] is GPU i's local flag and the remaining entries are global flags, one per peer. The flags from all GPUs form one array int all_flags[N][N] (which can be shared via P2P, or can be host-managed memory mapped to the device).

Every GPU concurrently executes the following:


    const int N = 4; // Number of GPUs
    int i = 0; // GPU index

    // all_flags is an N x N 2D array shared across all GPUs, initialized to 0.
    // (It is declared locally here only to keep the sketch short.)
    int all_flags[N][N] = {0};

    // Update flags for the current GPU
    all_flags[i][i] += 1;

    // Update flags for peer GPUs
    for (int j = 0; j < N; ++j) {
        if (j != i) {
            all_flags[j][i] += 1;
        }
    }

    // Wait until synchronization is achieved
    bool synced = false;
    while (!synced) {
        synced = true;
        for (int j = 0; j < N; ++j) {
            if (all_flags[i][j] != all_flags[i][i]) {
                synced = false;
                break; // No need to check further, already out of sync
            }
        }
    }

The diagram:

(image)

This essentially acts as a barrier for all GPUs.
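The barrier described above can be exercised on the host with threads standing in for GPUs. This is a simulation under stated assumptions, not the eventual CUDA implementation: `std::atomic<int>` replaces the shared device memory, and `Flags`, `barrier`, and `run_barrier_rounds` are hypothetical names. One deliberate deviation is that the wait uses `<` instead of `!=`, so that a fast peer that has already entered the next barrier does not stall a slower rank when the barrier is reused back to back.

```cpp
#include <atomic>
#include <cassert>
#include <thread>
#include <vector>

constexpr int N = 4;  // number of simulated GPUs (host threads here)

struct Flags {
    // all_flags[j][i] is written only by rank i and polled only by rank j.
    std::atomic<int> all_flags[N][N];
    Flags() {
        for (auto& row : all_flags)
            for (auto& f : row) f.store(0, std::memory_order_relaxed);
    }
};

void barrier(Flags& fl, int i) {
    // Bump the local flag, then the slot this rank owns on every peer.
    const int target =
        fl.all_flags[i][i].fetch_add(1, std::memory_order_release) + 1;
    for (int j = 0; j < N; ++j) {
        if (j != i) fl.all_flags[j][i].fetch_add(1, std::memory_order_release);
    }
    // Wait until every peer's slot in our row has reached our local flag.
    for (int j = 0; j < N; ++j) {
        while (fl.all_flags[i][j].load(std::memory_order_acquire) < target) {
            std::this_thread::yield();
        }
    }
}

// Each rank writes k, synchronizes, sums all inputs, and synchronizes again.
// Returns the number of rounds in which some rank saw a wrong sum.
int run_barrier_rounds(int rounds) {
    Flags fl;
    int data[N] = {};
    std::atomic<int> errors{0};
    std::vector<std::thread> ranks;
    for (int i = 0; i < N; ++i) {
        ranks.emplace_back([&, i] {
            for (int k = 1; k <= rounds; ++k) {
                data[i] = k;
                barrier(fl, i);  // all inputs for round k are now visible
                int sum = 0;
                for (int j = 0; j < N; ++j) sum += data[j];
                if (sum != k * N) errors.fetch_add(1);
                barrier(fl, i);  // nobody overwrites data before all have read
            }
        });
    }
    for (auto& t : ranks) t.join();
    return errors.load();
}
```

`run_barrier_rounds(100)` returns 0 here. Because each cell has a single writer, the increments never contend; the release/acquire pairs are what make the plain `data` accesses safe in this host version.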

HydraQYH commented 1 month ago

Regarding the __threadfence_system() fix at https://github.com/vllm-project/vllm/blob/0af3abe3d3225449c907d75eb3d2ae4b83bd21a1/csrc/custom_all_reduce.cuh#L137:

cc @youkaichao https://github.com/vllm-project/vllm/issues/8410#issuecomment-2348359732

HydraQYH commented 1 month ago

Moved to https://github.com/vllm-project/vllm/issues/8457

kanghui0204 commented 1 month ago

@youkaichao Yes, I agree with you.

hanzhi713 commented 1 month ago

@kanghui0204 I can take a stab at this idea if you haven't started. I happen to have some time this week.

kanghui0204 commented 1 month ago

@hanzhi713 Sorry, I haven't started on it because of the Mid-Autumn Festival. If you have time, please give it a try. Thanks, and happy Mid-Autumn Festival!

hanzhi713 commented 1 month ago

@kanghui0204 Sure. I will get started today. Happy holiday!