turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License
3.71k stars 283 forks source link

Want to try row split + all_reduce for MLP and attn #614

Open Azure-Tang opened 2 months ago

Azure-Tang commented 2 months ago

We are trying to use all-reduce TP to cut communication time. I noticed that you have implemented row split + all_reduce for the MLP (not faster, disabled). Why was this version abandoned? With a row split we only need 2 communications (all-reduces) per layer, but the released version uses 6 communications per layer via all-gathers. Could you share the row-split version so that I can test/modify it?

Anyway, I noticed that your ExLlamaV2Linear implements tp_split_row and forward_tp_row.

So I tried to modify ExLlamaV2MLP.tp_split and ExLlamaV2MLP.forward_tp_old to implement an all-reduce version of TP, but I got a segmentation fault in down_proj’s C++ kernel.


    def forward_tp_old(
        self,
        hidden_states: torch.Tensor,
        cache = None,
        attn_params = None,
        past_len = None,
        intermediates: bool = False,
        loras: list[ExLlamaV2Lora] | None = None,
        **kwargs
    ) -> torch.Tensor | dict[str, torch.Tensor]:
        """
        Tensor-parallel MLP forward pass using a row-split down projection.

        gate_proj and up_proj are split by output columns (BROADCAST_ID), so each
        device holds one slice of the intermediate activations. down_proj is split
        by input rows over the same devices, so each device produces a full-width
        *partial sum* of the MLP output. Those partials must be summed (reduced),
        not concatenated like the outputs of the column-split version.

        :param hidden_states:
            Input of shape (batch_size, q_len, hidden_size).
        :param cache, attn_params, past_len, intermediates, loras, kwargs:
            Accepted for interface compatibility; not used by this code path.
        :return:
            Tensor of shape (batch_size, q_len, hidden_size) on the device of the
            first partial sum.
        :raises ValueError:
            If cfg.arch.mlp_act_func is not "silu" or "gelu".
        """

        cfg = self.model.config
        ctx = self.model.tp_context

        act_func = cfg.arch.mlp_act_func
        if act_func not in ("silu", "gelu"):
            # Without this guard `output` below would be unbound (UnboundLocalError)
            raise ValueError(f"Unsupported mlp_act_func for TP forward: {act_func}")

        batch_size, q_len, _ = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        # broadcast() returns a list with one copy of the input per BROADCAST_ID device
        hidden_states = ctx.broadcast(0, hidden_states, BROADCAST_ID)

        residual = hidden_states

        post_norm = self.pre_layernorm.forward_tp(hidden_states, output_split = True) \
            if self.pre_layernorm else hidden_states

        # Column split: gate[i] / up[i] hold device i's slice of the intermediate features
        gate = self.gate_proj.forward_tp(post_norm, output_split = True)
        up = self.up_proj.forward_tp(post_norm, output_split = True)

        outputs = []
        for idx, hs in enumerate(post_norm):
            dev = hs.device.index
            context = self.model.get_device_context(dev)
            torch.cuda.set_stream(context.stream)

            if act_func == "silu":
                output = F.silu(gate[idx])
            else:
                output = F.gelu(gate[idx], approximate = "tanh")
            output *= up[idx]
            outputs.append(output)

        # No all-gather of `outputs`: with a row-split down_proj each device consumes
        # only its own slice of the intermediate state.
        # NOTE: the gate/up output columns must be ordered to match down_proj's input
        # rows. If down_proj applies an input permutation, that permutation has to be
        # folded into gate_proj/up_proj first, otherwise the partial sums are wrong.
        down = self.down_proj.forward_tp_row(outputs, output_split = True)

        # Reduce: every element of `down` is a full-width partial sum on its own device.
        # A cross-device .to() is ordered against the current streams of both devices
        # (set to each device context's stream above).
        # clone() so the result never aliases a shared scratch buffer.
        target = down[0].device
        reduced = down[0].clone()
        for part in down[1:]:
            reduced += part.to(target)

        if self.has_residual:
            # residual holds one full copy per device; add it exactly once, after the
            # reduction, so it is not counted len(down) times
            reduced += residual[0].to(target)

        # NOTE(review): the column-split version returned tp_context.gather()'s result;
        # confirm the next layer's broadcast() accepts a tensor on a CUDA device.
        return reduced.view(batch_size, q_len, reduced.shape[-1])
    def tp_split(self):
        """
        Split this MLP's modules across devices for tensor parallelism and reserve
        per-device scratch buffers.

        Norms are split with BROADCAST_RS. gate_proj and up_proj are split with
        BROADCAST_ID via tp_split. down_proj is split by rows via tp_split_row, so
        that each device's slice of its input lines up with that device's gate/up
        output slice.

        The temp_* scratch slices are not referenced by forward_tp_old; presumably
        they serve the C++ forward path (ext_c.tp_mlp_forward_). TODO confirm.
        """

        cfg = self.model.config
        ctx = self.model.tp_context

        if self.pre_layernorm is not None:
            self.pre_layernorm.tp_split(BROADCAST_RS)
        if self.post_layernorm is not None:
            self.post_layernorm.tp_split(BROADCAST_RS)
        if self.gate_proj is not None:
            self.gate_proj.tp_split(BROADCAST_ID)
        if self.up_proj is not None:
            self.up_proj.tp_split(BROADCAST_ID)
        if self.down_proj is not None:
            # Changed from tp_split to tp_split_row (row-parallel down projection).
            # NOTE(review): the second positional argument here is a broadcast-type
            # constant; confirm against tp_split_row's signature that this is intended.
            # The row boundaries must match the gate_proj/up_proj column split exactly,
            # otherwise the per-device GEMM shapes disagree (a plausible cause of the
            # segfault in down_proj's kernel) — verify.
            self.down_proj.tp_split_row(BROADCAST_ID, BROADCAST_ID) # changed from tp_split to tp_split_row

        # Scratch is sized for the largest possible number of flattened rows
        maxrows = cfg.max_batch_size * cfg.max_input_len
        dtype = torch.half

        # Slices are taken from the TP scratch area one after another following
        # begin_scratch_alloc_tp(); keep the order stable (presumably offsets depend on it).
        ctx.begin_scratch_alloc_tp()
        ctx.reserve_scratch(self.tp_dq_size)
        self.temp_bc0 = ctx.get_scratch_slice_tp_bc(maxrows, dtype, BROADCAST_RS)
        self.temp_bc1 = ctx.get_scratch_slice_tp_bc(maxrows, dtype, BROADCAST_RS)
        self.temp_bc2 = ctx.get_scratch_slice_tp_bc(maxrows, dtype, BROADCAST_ID)
        self.temp_gate = ctx.get_scratch_slice_tp(maxrows, dtype, BROADCAST_ID)
        self.temp_up = ctx.get_scratch_slice_tp(maxrows, dtype, BROADCAST_ID)
        # NOTE(review): with a row-split down_proj each device yields a full-width
        # partial sum, but this slice is still sized like a BROADCAST_RS column split.
        # Confirm it is large enough; perhaps the _bc variant is needed.
        self.temp_down = ctx.get_scratch_slice_tp(maxrows, dtype, BROADCAST_RS)

        self.is_tp = True

The gdb backtrace is below:

(gdb) bt
#0  0x00007feffe68689d in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#1  0x00007ff0ac613d52 in ?? () from /opt/conda/lib/python3.10/site-packages/torch/lib/../../nvidia/cuda_runtime/lib/libcudart.so.12
#2  0x00007ff0ac64e5fe in cudaEventElapsedTime () from /opt/conda/lib/python3.10/site-packages/torch/lib/../../nvidia/cuda_runtime/lib/libcudart.so.12
#3  0x00007fef45a9f188 in gemm_half_q_half_cuda_part (stream=stream@entry=0xcb42ac0, a=a@entry=0x7fedc9e07600, b=b@entry=0x2dd65520, c=c@entry=0x7fedc9e08e00, size_m=size_m@entry=1, 
    size_n=size_n@entry=11008, size_k=3072, m_count=1, clear=true, r_weights=0x0, r_weights_stride=0, mul_r_weights=false, graph=0x0, label=0)
    at /root/azure/exllamav2/exllamav2/exllamav2_ext/cuda/q_gemm.cu:149
#4  0x00007fef45a9f3a1 in gemm_half_q_half_cuda (stream=0xcb42ac0, cublas_handle=0x316f4730, a=0x7fedc9e07600, b=0x2dd65520, c=0x7fedc9e08e00, size_m=1, size_n=11008, size_k=3072, 
    clear=true, temp_dq=0x0, force_cuda=false, r_weights=0x0, r_weights_stride=0, mul_r_weights=false, graph=0x0, label=0)
    at /root/azure/exllamav2/exllamav2/exllamav2_ext/cuda/q_gemm.cu:293
#5  0x00007fef45acf54d in gemm_half_q_half_tp(std::vector<at::Tensor, std::allocator<at::Tensor> > const&, std::vector<unsigned long, std::allocator<unsigned long> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, bool, unsigned long, int) () from /opt/conda/lib/python3.10/site-packages/exllamav2_ext.cpython-310-x86_64-linux-gnu.so
#6  0x00007fef45abcffb in pybind11::cpp_function::initialize<void (*&)(std::vector<at::Tensor, std::allocator<at::Tensor> > const&, std::vector<unsigned long, std::allocator<unsigned long> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, bool, unsigned long, int), void, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, std::vector<unsigned long, std::allocator<unsigned long> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, bool, unsigned long, int, pybind11::name, pybind11::scope, pybind11::sibling, char [20]>(void (*&)(std::vector<at::Tensor, std::allocator<at::Tensor> > const&, std::vector<unsigned long, std::allocator<unsigned long> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, bool, unsigned long, int), void (*)(std::vector<at::Tensor, std::allocator<at::Tensor> > const&, std::vector<unsigned long, std::allocator<unsigned long> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, bool, unsigned long, int), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, char const (&) [20])::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) ()
   from /opt/conda/lib/python3.10/site-packages/exllamav2_ext.cpython-310-x86_64-linux-gnu.so
#7  0x00007fef45a304f8 in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) () from /opt/conda/lib/python3.10/site-packages/exllamav2_ext.cpython-310-x86_64-linux-gnu.so
#8  0x00000000004fc697 in cfunction_call (func=0x7fef52dd79c0, args=<optimized out>, kwargs=<optimized out>) at /usr/local/src/conda/python-3.10.13/Objects/methodobject.c:543
#9  0x00000000004f614b in _PyObject_MakeTpCall (tstate=0x1201e40, callable=0x7fef52dd79c0, args=<optimized out>, nargs=<optimized out>, keywords=0x0)
    at /usr/local/src/conda/python-3.10.13/Objects/call.c:215
--Type <RET> for more, q to quit, c to continue without paging--q
turboderp commented 2 months ago

Row-parallel doesn't work for attn because the attn output projection is shuffled and it would need a little extra work to unshuffle it so the input columns actually align with the rows of each slice of the split tensor. I've fixed this for the MLP so far, by applying the input permutation from the down projection to the up and gate tensors (which is a small optimization regardless of TP), meaning the down projection can be trivially split by row. It's definitely not abandoned, just disabled for now since it's incomplete.

For the record, it's not really six communications, but four:

attn:

The intermediate state between layers is the same pinned CPU tensor that exists in the middle of each all-gather anyway. This simplifies the implementation somewhat and allows attn and MLP layers to use different sets of devices in the case of uneven splits.

Everything could be reduced to two all-reduces in principle, or one all-reduce plus two all-gathers without a solution to the permutation issue during attn, but the reason I put it on hold and released the (experimental) feature in the state it's currently in is that I haven't yet found a way to do the all-reduce that's actually more efficient, not to mention user-friendly enough to be worth considering.

There are already solutions for people who want maximally-efficient inference on server hardware with 2^n identical, headless, P2P-enabled devices, appropriate BIOS settings for compute workloads, and so on. There's no point in trying to compete with NVIDIA for performance, but I think ExLlama can still offer flexibility and usability advantages, especially for affordable hardware. Libraries like NCCL come with really annoying limitations, like an inability to gather tensors of different shapes, incompatibility with Windows, and apparently a strong preference for multiprocessing. At least a single-process all-reduce with NCCL does not appear to be any faster than two all-gathers even when one of the latter is the MLP intermediate state which is 3x as large as the residual stream.

Obviously P2P would help, but for consumer hardware that currently comes down to the hacked NVIDIA driver that geohot doesn't seem to be too interested in outside the context of TinyGrad (which is perfectly fair, to be clear).

So, where I'm currently at, I'm experimenting, whenever I have time, to come up with a more efficient all-reduce operation that, ideally, works within a single process (though multithreaded is fine of course) and doesn't completely tank performance without P2P or on slow PCIe links.

Azure-Tang commented 2 months ago

Thanks for answering!

So for the MLP, I should be able to split down_proj by rows directly with tp_split_row and run it with forward_tp_row, right? Why am I getting a segmentation fault now?

And you mentioned that:

Row-parallel doesn't work for attn because the attn output projection is shuffled and it would need a little extra work to unshuffle it so the input columns actually align with the rows of each slice of the split tensor.

How would I do this? I couldn't find your MLP example in linear.py. I'd like to see it so that I can try to unshuffle o_proj myself until you have time. I'm trying to build a special two-GPU TP version with maximum inference speed, so two all-reduces could be more efficient.

turboderp commented 2 months ago

Problem is it's not o_proj that needs to be unshuffled, it's the input to o_proj that needs to be pre-shuffled to match the permutation in o_proj so that this permutation can be deleted. This is easy enough to do for the MLP layers since it's just linear -> act_fun -> linear, so you can rearrange the output features of the first linear layer(s) and that'll suffice as the activation function works on independent elements. But for attention this would mean moving channels between attn heads and that would just break attention.

Changing to row split just in the MLP also requires a few other changes. There are two different versions of the forward pass, one in Python and one in C++, and they'd both have to be updated. The logic is probably easiest to follow in the (poorly named) forward_tp_old function which you could update to start with (just replace the call to forward_tp), then replicate the changes in ext_c.tp_mlp_forward_ afterwards.

It currently goes like this:

The all-reduce version would have to be:

The output could then be passed directly into the next attention layer, but that attention layer has to skip broadcasting the already-broadcast tensor (conditionally since the embedding layer would still output a single tensor in pinned memory). And that should be it.

Azure-Tang commented 2 months ago

The all-reduce version would have to be:

1. apply post_attention_layernorm (returns the broadcast tensor as a list)
2. apply gate and up (creates two lists of column-split outputs)
3. activation function
4. apply row-split down projection (tensors are now full-size partial sums of the result)
5. add residual stream on one device
6. all-reduce (results in broadcast state again)

This seems to be exactly what I did. I modified forward_tp_old, but I got a segmentation fault when applying the row-split down projection. What did I miss?

Anyway, I noticed that your ExLlamaV2Linear implements tp_split_row and forward_tp_row.

So I tried to modify ExLlamaV2MLP.tp_split and ExLlamaV2MLP.forward_tp_old to implement an all-reduce version of TP, but I got a segmentation fault in down_proj’s C++ kernel.

    def forward_tp_old(
        self,
        hidden_states: torch.Tensor,
        cache = None,
        attn_params = None,
        past_len = None,
        intermediates: bool = False,
        loras: list[ExLlamaV2Lora] | None = None,
        **kwargs
    ) -> torch.Tensor | dict[str, torch.Tensor]:
        """
        Tensor-parallel MLP forward pass using a row-split down projection.

        gate_proj and up_proj are split by output columns (BROADCAST_ID), so each
        device holds one slice of the intermediate activations. down_proj is split
        by input rows over the same devices, so each device produces a full-width
        *partial sum* of the MLP output. Those partials must be summed (reduced),
        not concatenated like the outputs of the column-split version.

        :param hidden_states:
            Input of shape (batch_size, q_len, hidden_size).
        :param cache, attn_params, past_len, intermediates, loras, kwargs:
            Accepted for interface compatibility; not used by this code path.
        :return:
            Tensor of shape (batch_size, q_len, hidden_size) on the device of the
            first partial sum.
        :raises ValueError:
            If cfg.arch.mlp_act_func is not "silu" or "gelu".
        """

        cfg = self.model.config
        ctx = self.model.tp_context

        act_func = cfg.arch.mlp_act_func
        if act_func not in ("silu", "gelu"):
            # Without this guard `output` below would be unbound (UnboundLocalError)
            raise ValueError(f"Unsupported mlp_act_func for TP forward: {act_func}")

        batch_size, q_len, _ = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        # broadcast() returns a list with one copy of the input per BROADCAST_ID device
        hidden_states = ctx.broadcast(0, hidden_states, BROADCAST_ID)

        residual = hidden_states

        post_norm = self.pre_layernorm.forward_tp(hidden_states, output_split = True) \
            if self.pre_layernorm else hidden_states

        # Column split: gate[i] / up[i] hold device i's slice of the intermediate features
        gate = self.gate_proj.forward_tp(post_norm, output_split = True)
        up = self.up_proj.forward_tp(post_norm, output_split = True)

        outputs = []
        for idx, hs in enumerate(post_norm):
            dev = hs.device.index
            context = self.model.get_device_context(dev)
            torch.cuda.set_stream(context.stream)

            if act_func == "silu":
                output = F.silu(gate[idx])
            else:
                output = F.gelu(gate[idx], approximate = "tanh")
            output *= up[idx]
            outputs.append(output)

        # No all-gather of `outputs`: with a row-split down_proj each device consumes
        # only its own slice of the intermediate state.
        # NOTE: the gate/up output columns must be ordered to match down_proj's input
        # rows. If down_proj applies an input permutation, that permutation has to be
        # folded into gate_proj/up_proj first, otherwise the partial sums are wrong.
        down = self.down_proj.forward_tp_row(outputs, output_split = True)

        # Reduce: every element of `down` is a full-width partial sum on its own device.
        # A cross-device .to() is ordered against the current streams of both devices
        # (set to each device context's stream above).
        # clone() so the result never aliases a shared scratch buffer.
        target = down[0].device
        reduced = down[0].clone()
        for part in down[1:]:
            reduced += part.to(target)

        if self.has_residual:
            # residual holds one full copy per device; add it exactly once, after the
            # reduction, so it is not counted len(down) times
            reduced += residual[0].to(target)

        # NOTE(review): the column-split version returned tp_context.gather()'s result;
        # confirm the next layer's broadcast() accepts a tensor on a CUDA device.
        return reduced.view(batch_size, q_len, reduced.shape[-1])
    def tp_split(self):
        """
        Split this MLP's modules across devices for tensor parallelism and reserve
        per-device scratch buffers.

        Norms are split with BROADCAST_RS. gate_proj and up_proj are split with
        BROADCAST_ID via tp_split. down_proj is split by rows via tp_split_row, so
        that each device's slice of its input lines up with that device's gate/up
        output slice.

        The temp_* scratch slices are not referenced by forward_tp_old; presumably
        they serve the C++ forward path (ext_c.tp_mlp_forward_). TODO confirm.
        """

        cfg = self.model.config
        ctx = self.model.tp_context

        if self.pre_layernorm is not None:
            self.pre_layernorm.tp_split(BROADCAST_RS)
        if self.post_layernorm is not None:
            self.post_layernorm.tp_split(BROADCAST_RS)
        if self.gate_proj is not None:
            self.gate_proj.tp_split(BROADCAST_ID)
        if self.up_proj is not None:
            self.up_proj.tp_split(BROADCAST_ID)
        if self.down_proj is not None:
            # Changed from tp_split to tp_split_row (row-parallel down projection).
            # NOTE(review): the second positional argument here is a broadcast-type
            # constant; confirm against tp_split_row's signature that this is intended.
            # The row boundaries must match the gate_proj/up_proj column split exactly,
            # otherwise the per-device GEMM shapes disagree (a plausible cause of the
            # segfault in down_proj's kernel) — verify.
            self.down_proj.tp_split_row(BROADCAST_ID, BROADCAST_ID) # changed from tp_split to tp_split_row

        # Scratch is sized for the largest possible number of flattened rows
        maxrows = cfg.max_batch_size * cfg.max_input_len
        dtype = torch.half

        # Slices are taken from the TP scratch area one after another following
        # begin_scratch_alloc_tp(); keep the order stable (presumably offsets depend on it).
        ctx.begin_scratch_alloc_tp()
        ctx.reserve_scratch(self.tp_dq_size)
        self.temp_bc0 = ctx.get_scratch_slice_tp_bc(maxrows, dtype, BROADCAST_RS)
        self.temp_bc1 = ctx.get_scratch_slice_tp_bc(maxrows, dtype, BROADCAST_RS)
        self.temp_bc2 = ctx.get_scratch_slice_tp_bc(maxrows, dtype, BROADCAST_ID)
        self.temp_gate = ctx.get_scratch_slice_tp(maxrows, dtype, BROADCAST_ID)
        self.temp_up = ctx.get_scratch_slice_tp(maxrows, dtype, BROADCAST_ID)
        # NOTE(review): with a row-split down_proj each device yields a full-width
        # partial sum, but this slice is still sized like a BROADCAST_RS column split.
        # Confirm it is large enough; perhaps the _bc variant is needed.
        self.temp_down = ctx.get_scratch_slice_tp(maxrows, dtype, BROADCAST_RS)

        self.is_tp = True

The gdb backtrace is below:

(gdb) bt
#0  0x00007feffe68689d in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#1  0x00007ff0ac613d52 in ?? () from /opt/conda/lib/python3.10/site-packages/torch/lib/../../nvidia/cuda_runtime/lib/libcudart.so.12
#2  0x00007ff0ac64e5fe in cudaEventElapsedTime () from /opt/conda/lib/python3.10/site-packages/torch/lib/../../nvidia/cuda_runtime/lib/libcudart.so.12
#3  0x00007fef45a9f188 in gemm_half_q_half_cuda_part (stream=stream@entry=0xcb42ac0, a=a@entry=0x7fedc9e07600, b=b@entry=0x2dd65520, c=c@entry=0x7fedc9e08e00, size_m=size_m@entry=1, 
    size_n=size_n@entry=11008, size_k=3072, m_count=1, clear=true, r_weights=0x0, r_weights_stride=0, mul_r_weights=false, graph=0x0, label=0)
    at /root/azure/exllamav2/exllamav2/exllamav2_ext/cuda/q_gemm.cu:149
#4  0x00007fef45a9f3a1 in gemm_half_q_half_cuda (stream=0xcb42ac0, cublas_handle=0x316f4730, a=0x7fedc9e07600, b=0x2dd65520, c=0x7fedc9e08e00, size_m=1, size_n=11008, size_k=3072, 
    clear=true, temp_dq=0x0, force_cuda=false, r_weights=0x0, r_weights_stride=0, mul_r_weights=false, graph=0x0, label=0)
    at /root/azure/exllamav2/exllamav2/exllamav2_ext/cuda/q_gemm.cu:293
#5  0x00007fef45acf54d in gemm_half_q_half_tp(std::vector<at::Tensor, std::allocator<at::Tensor> > const&, std::vector<unsigned long, std::allocator<unsigned long> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, bool, unsigned long, int) () from /opt/conda/lib/python3.10/site-packages/exllamav2_ext.cpython-310-x86_64-linux-gnu.so
#6  0x00007fef45abcffb in pybind11::cpp_function::initialize<void (*&)(std::vector<at::Tensor, std::allocator<at::Tensor> > const&, std::vector<unsigned long, std::allocator<unsigned long> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, bool, unsigned long, int), void, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, std::vector<unsigned long, std::allocator<unsigned long> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, bool, unsigned long, int, pybind11::name, pybind11::scope, pybind11::sibling, char [20]>(void (*&)(std::vector<at::Tensor, std::allocator<at::Tensor> > const&, std::vector<unsigned long, std::allocator<unsigned long> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, bool, unsigned long, int), void (*)(std::vector<at::Tensor, std::allocator<at::Tensor> > const&, std::vector<unsigned long, std::allocator<unsigned long> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, bool, unsigned long, int), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, char const (&) [20])::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) ()
   from /opt/conda/lib/python3.10/site-packages/exllamav2_ext.cpython-310-x86_64-linux-gnu.so
#7  0x00007fef45a304f8 in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) () from /opt/conda/lib/python3.10/site-packages/exllamav2_ext.cpython-310-x86_64-linux-gnu.so
#8  0x00000000004fc697 in cfunction_call (func=0x7fef52dd79c0, args=<optimized out>, kwargs=<optimized out>) at /usr/local/src/conda/python-3.10.13/Objects/methodobject.c:543
#9  0x00000000004f614b in _PyObject_MakeTpCall (tstate=0x1201e40, callable=0x7fef52dd79c0, args=<optimized out>, nargs=<optimized out>, keywords=0x0)
    at /usr/local/src/conda/python-3.10.13/Objects/call.c:215
--Type <RET> for more, q to quit, c to continue without paging--q