triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

How to leverage hardware intrinsics like convolution in triton #1011

Open sethbrin opened 1 year ago

sethbrin commented 1 year ago

Background

We are evaluating how to build a new backend for Triton, but there is a problem: our hardware has some relatively coarse-grained instructions, such as convolution, and the Triton language has no similar mechanism to express them.

The following is the description of the conv instruction (see https://www.cambricon.com/docs/sdk_1.9.0/cntoolkit_3.1.4/cambricon_bang_c_4.1.3/2Builtin-Functions/Artificial%20Intelligence%20Functions.html#bang-conv for more information):

void __bang_conv(float *dst, float *src, float *kernel, int channel_input, int height, int width, int kernel_height, int kernel_width, int stride_width, int stride_height, int channel_output)
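To make the semantics of such an instruction concrete, here is a naive NumPy reference for a convolution with the same parameter list. It is only a sketch under assumptions that are not taken from the Cambricon documentation: a single HWC input of shape (height, width, channel_input), a kernel of shape (channel_output, kernel_height, kernel_width, channel_input), no padding, and an output of shape (out_h, out_w, channel_output). The real layout and padding rules of __bang_conv are whatever the linked documentation specifies.

```python
import numpy as np


def reference_conv(src, kernel, stride_width, stride_height):
    """Naive "valid" convolution (no padding), assumed layouts:
    src:    (height, width, channel_input)
    kernel: (channel_output, kernel_height, kernel_width, channel_input)
    returns (out_h, out_w, channel_output)
    Stride arguments are ordered as in the __bang_conv signature.
    """
    height, width, channel_input = src.shape
    channel_output, kernel_height, kernel_width, kc = kernel.shape
    assert kc == channel_input, "kernel and input channel counts must match"
    out_h = (height - kernel_height) // stride_height + 1
    out_w = (width - kernel_width) // stride_width + 1
    dst = np.zeros((out_h, out_w, channel_output), dtype=np.result_type(src, kernel))
    for oh in range(out_h):
        for ow in range(out_w):
            h0 = oh * stride_height
            w0 = ow * stride_width
            # input window of shape (kernel_height, kernel_width, channel_input)
            window = src[h0:h0 + kernel_height, w0:w0 + kernel_width, :]
            for co in range(channel_output):
                # dot product of the window with one output channel's filter
                dst[oh, ow, co] = np.sum(window * kernel[co])
    return dst
```

For example, a 4×4 single-channel input of ones convolved with a 2×2 kernel of ones at stride 2 gives a 2×2 output whose entries are all 4.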

We notice that the Triton language has a dot op, which maps easily to the mma instruction in the NV backend. Similarly, PyTorch Inductor builds convolution on top of the dot op. The following code snippet is from PyTorch Inductor (with the code that special-cases conv1x1 and the prefetching logic removed):

# compute offsets for the output and the initial pointers for the input
...
# -----------------------------------------------------------
# allocate accumulator
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for crs in range(0, CRS, BLOCK_K):
    # offsets of the current K-block along the flattened C*R*S reduction axis
    off_x_crs = crs + tl.arange(0, BLOCK_K)
    # load this block's (h, w, c) input deltas and rebuild x_ptrs
    delta_xh = tl.load(delta_xh_ptrs, mask=off_x_crs < CRS, other=0)
    delta_xw = tl.load(delta_xw_ptrs, mask=off_x_crs < CRS, other=0)
    delta_xc = tl.load(delta_xc_ptrs, mask=off_x_crs < CRS, other=0)
    off_x_crs_unpacked = (
        delta_xh * stride_xh + delta_xw * stride_xw + delta_xc * stride_xc
    )
    x_ptrs = x + off_x_nhw[:, None] + off_x_crs_unpacked[None, :]

    # mask out-of-range batch/reduction indices and out-of-bounds (padding) input pixels
    mask_x = (
        (off_x_n < BATCH)[:, None]
        & (off_x_crs < CRS)[None, :]
        & (off_x_h[:, None] + delta_xh[None, :] >= 0)
        & (off_x_h[:, None] + delta_xh[None, :] < IN_H)
        & (off_x_w[:, None] + delta_xw[None, :] >= 0)
        & (off_x_w[:, None] + delta_xw[None, :] < IN_W)
    )
    mask_w = (off_x_crs < CRS)[:, None] & (off_w_k < KERNEL_N)[None, :]
    # ------ load x ------
    matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0)
    # ------ load w ------
    matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0)
    # ------ matrix multiplication ------
    acc += tl.dot(matrix_x, matrix_w)
    # ------ advance pointers to the next K-block ------
    delta_xh_ptrs += BLOCK_K
    delta_xw_ptrs += BLOCK_K
    delta_xc_ptrs += BLOCK_K
    w_ptrs += BLOCK_K

The code above can be divided into three steps:

  1. calculate the offset of the data
  2. load the data
  3. call dot to do multiplication and accumulation
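The three steps can be reproduced on the host with NumPy, which may help when reasoning about what a backend would have to recognize. This is a sketch, not Inductor's actual code: it assumes a single contiguous NHWC input, (K, R, S, C) filters, no padding, and a reduction index ordered as crs = (r * S + s) * C + c. The names delta_xh, off_x_nhw and so on mirror the kernel above, but the index ordering and the function conv_via_offsets are my own assumptions.

```python
import numpy as np


def conv_via_offsets(x, w, stride=1):
    """Host-side sketch of the kernel's three steps (no padding).

    x: (N, H, W, C) NHWC input
    w: (K, R, S, C) filters
    returns (N, P, Q, K)
    """
    N, H, W, C = x.shape
    K, R, S, _ = w.shape
    P = (H - R) // stride + 1
    Q = (W - S) // stride + 1
    CRS = C * R * S
    # element strides of a C-ordered NHWC buffer
    stride_xn, stride_xh, stride_xw, stride_xc = H * W * C, W * C, C, 1

    # Step 1a: split each reduction index crs = (r * S + s) * C + c into deltas
    crs = np.arange(CRS)
    delta_xh = crs // (S * C)
    delta_xw = (crs // C) % S
    delta_xc = crs % C
    off_x_crs_unpacked = delta_xh * stride_xh + delta_xw * stride_xw + delta_xc * stride_xc

    # Step 1b: base offset of every output pixel (n, p, q), one row of the M axis each
    n, p, q = np.meshgrid(np.arange(N), np.arange(P), np.arange(Q), indexing="ij")
    off_x_nhw = (n * stride_xn + p * stride * stride_xh + q * stride * stride_xw).ravel()

    # Step 2: "load" an (N*P*Q, CRS) matrix by gathering from the flat buffer
    matrix_x = x.ravel()[off_x_nhw[:, None] + off_x_crs_unpacked[None, :]]
    # filters flattened in the same (r, s, c) order, transposed to (CRS, K)
    matrix_w = w.reshape(K, CRS).T

    # Step 3: multiply-accumulate, which tl.dot performs block by block
    acc = matrix_x @ matrix_w
    return acc.reshape(N, P, Q, K)
```

Note that nothing in steps 1 and 2 says "convolution" explicitly: the window structure is hidden in the offset arithmetic, which is why recognizing it from the IR is hard.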

There are several possible ways to map the logic above onto the conv instruction, listed below; the list may not be complete. I don't know which one Triton officially prefers, or whether there is a better solution.

Proposal 1: Try to identify convolution instructions on Triton IR

The principle of this proposal is very simple, but it may be difficult to implement. It is based on the Triton IR obtained from the triton language, and tries to generate a conv instruction by analyzing the instruction.

But given the logic above, how would we recognize that it is a convolution with parameters [KH, KW, PH, PW, ...]?

Proposal 2: Add conv ops in triton language

Similar to the dot op, we would introduce a conv op in the Triton language:

tl.conv(input, weights, IC, ....)
Jokeren commented 1 year ago

Hi Ping, we are on vacation from December 22nd until January 2nd, so we will have a discussion on this issue probably after Tuesday, January 3rd. Apologize for any inconvenience and thank you for your understanding.

sethbrin commented 1 year ago

@Jokeren I appreciate your timely response.

Jokeren commented 1 year ago

Hi Ping, I had a discussion with @ptillet today.

We probably prefer proposal 1 instead of defining a new tl.conv operation.

Another way to handle this problem is to define an external function, as we did for libdevice. Currently, you can only encapsulate the conv routine as bitcode, but invoking it from a binary is also possible with some modifications.

sethbrin commented 1 year ago

@Jokeren Is the alternative you mentioned to define a tl.cambricon.conv external operation?

When integrating into a framework such as Inductor, we would define a special template that uses the tl.cambricon.conv op to implement the convolution kernel. This solution is reasonable. Apart from such special operators, regular operators such as add/mul/reduce could share the same template implementation across multiple backends.

Jokeren commented 1 year ago

> @Jokeren Is the alternative you mentioned to define a tl.cambricon.conv external operation?

Nope, I meant "identifying" a conv operation, not "defining" one. As you said, this might be difficult, but we are worried that conv is not a native op on other architectures.

> It is based on the Triton IR obtained from the triton language, and tries to generate a conv instruction by analyzing the instruction.

Cambricon could be another dialect, like TritonGPU, underneath Triton. What do you think?

sethbrin commented 1 year ago

You gave libdevice as an example, which is equivalent to defining an external operator. It is exposed transparently to the Triton frontend; that is, a kernel can call tanh from libdevice through tl.libdevice.tanh. For example, we can define a kernel like this:

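import triton
import triton.language as tl


@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
    # load a block of BLOCK elements from X
    x = tl.load(X + tl.arange(0, BLOCK))
    # call tanh from NVIDIA's libdevice bitcode
    y = tl.libdevice.tanh(x)
    tl.store(Y + tl.arange(0, BLOCK), y)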

Calling functions in libdevice is NV-specific: internally, each call maps directly to the binary implementation in libdevice.bc. Doesn't that make it inapplicable to other backends, such as CPU?

So I wonder whether the special conv operator of the Cambricon backend could also be integrated as a hardware-specific intrinsic, so that we can define a kernel like this:

@triton.jit
def kernel(inp, weight, output, padding, stride,  BATCH, ...):
    y = tl.cambricon.conv(inp, weight, ...)
    tl.store(...)
Jokeren commented 1 year ago

If the Cambricon backend can be encapsulated in a bitcode file, it could be invoked the same way as libdevice. In that case, tl.cambricon is a file that consists of external functions, so maybe you want to name it libcambricon to differentiate it from the Cambricon dialect? Not sure.

Jokeren commented 1 year ago

The extern utility is definitely not mature enough yet to support all the features you would request, but we can work on it together if you think it is doable. I do think enhancing it could benefit other vendors as well.

sethbrin commented 1 year ago

Yes, it should be named libcambricon. At first glance, the extern utility seems to meet our needs. We will evaluate it in more detail later and look forward to improving it together.

Jokeren commented 1 year ago

Sounds good!

ptillet commented 1 year ago

Hey!

Ideally, we'd like to find a way to represent convolutions in Triton-IR in a way that makes it performance-portable across H100 and other accelerators. There's still a case to be made for more complex external functions, but my main worry is that we won't be able to come up with hardware-independent calling conventions (e.g., passing tensors through shared memory on GPUs wouldn't be portable to architectures without shared mem).

sethbrin commented 1 year ago

Hi, @ptillet, I discussed this issue with @Jokeren offline.

Since Triton plans to add linalg dialect support in the future, and the linalg dialect defines a linalg.conv_2d operator, we are considering lowering directly to the Cambricon backend through linalg: that is, defining operators like triton.linalg.conv in the frontend instead of using the current external function mechanism.

Would the calling-convention problem still exist in that case?

@Jokeren If I missed anything, please add to it.

ptillet commented 1 year ago

There is actually no plan to support the linalg dialect, though there is a plan for Triton to be flexible enough to represent everything the linalg dialect can represent.

We should definitely have some tl.conv2d, but the point I was trying to make is that it should probably be implemented as a combination of im2col and matmul, for maximum portability. Then a Cambricon backend could pattern match im2col + matmul into conv2d. Do you think that would work?
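As a rough illustration of the rewrite being proposed, the sketch below pattern-matches an im2col whose only consumer is a matmul in a tiny, made-up op list and replaces the pair with a single conv2d op. The Op class, the op names and the attribute keys are all hypothetical; a real implementation would be an MLIR rewrite pattern over Triton IR in the backend's own lowering pipeline.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    """Hypothetical toy IR op: `name` computes `result` from `operands`."""
    name: str
    result: str
    operands: list
    attrs: dict = field(default_factory=dict)


def fuse_im2col_matmul(ops):
    """Rewrite `cols = im2col(x); y = matmul(cols, w)` into `y = conv2d(x, w)`
    when `cols` has no other users. Returns a new op list."""
    producers = {op.result: op for op in ops}
    # count how many times each value is used as an operand
    uses = {}
    for op in ops:
        for v in op.operands:
            uses[v] = uses.get(v, 0) + 1
    fused_away = set()
    rewritten = []
    for op in ops:
        if op.name == "matmul":
            lhs = producers.get(op.operands[0])
            if lhs is not None and lhs.name == "im2col" and uses[lhs.result] == 1:
                # carry the im2col window attributes (kernel size, layout, ...) over
                rewritten.append(Op("conv2d", op.result,
                                    [lhs.operands[0], op.operands[1]], dict(lhs.attrs)))
                fused_away.add(lhs.result)
                continue
        rewritten.append(op)
    # drop the im2col ops whose only consumer was fused into a conv2d
    return [op for op in rewritten if op.result not in fused_away]
```

The single-use check matters: if the im2col result were also consumed elsewhere, fusing it away would leave that other consumer without an input.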

sethbrin commented 1 year ago

Yes, we can pattern match im2col+matmul into conv2d.

> though there is a plan for Triton to be flexible enough to represent everything the linalg dialect can represent.

The linalg dialect has many kinds of conv: different layout combinations, different dimensionalities, and grouped/depthwise conv. To represent everything the linalg dialect can represent, I worry that the semantics of im2col would become very complicated. Isn't this just transferring the complexity/portability problem from conv to im2col?
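To illustrate the concern, here is a NumPy sketch showing that a grouped convolution (depthwise is the case groups == C) still reduces to im2col plus matmul, but only as a set of independent per-group matmuls, so an im2col primitive or pattern would have to expose the group structure. Assumptions: NHWC input, filters of shape (K, R, S, C // groups), stride 1, no padding; the helper names im2col_nhwc and grouped_conv_im2col are hypothetical, and numpy.lib.stride_tricks.sliding_window_view requires NumPy 1.20 or later.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def im2col_nhwc(x, R, S):
    """x: (N, H, W, C) -> (N*P*Q, R*S*C) patch matrix, stride 1, no padding."""
    N, _, _, C = x.shape
    # sliding windows have shape (N, P, Q, C, R, S): window axes are appended last
    win = sliding_window_view(x, (R, S), axis=(1, 2))
    # reorder to (N, P, Q, R, S, C) so each row is ordered (r, s, c)
    win = win.transpose(0, 1, 2, 4, 5, 3)
    P, Q = win.shape[1], win.shape[2]
    return win.reshape(N * P * Q, R * S * C), (N, P, Q)


def grouped_conv_im2col(x, w, groups):
    """x: (N, H, W, C); w: (K, R, S, C // groups); returns (N, P, Q, K)."""
    C = x.shape[-1]
    K, R, S, Cg = w.shape
    assert C % groups == 0 and K % groups == 0 and Cg == C // groups
    Kg = K // groups
    outs = []
    for g in range(groups):
        # im2col over this group's input channels only
        cols, (N, P, Q) = im2col_nhwc(x[..., g * Cg:(g + 1) * Cg], R, S)
        # this group's filters, flattened in the same (r, s, c) order
        wg = w[g * Kg:(g + 1) * Kg].reshape(Kg, R * S * Cg)
        outs.append((cols @ wg.T).reshape(N, P, Q, Kg))
    # output channels of consecutive groups are concatenated
    return np.concatenate(outs, axis=-1)
```

With groups == 1 this is a single im2col + matmul; as groups grows toward C, each matmul shrinks to a reduction over R * S elements, which is one reason depthwise conv maps poorly onto a plain matmul unit.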

ptillet commented 1 year ago

Yes, it is :p H100 TMAs are actually capable of performing NHWC im2col asynchronously as part of the HBM -> Shared Memory data transfer, and all other architectures that I know of (TPUs, AMD GPUs) also implement convolutions in a similar way. Does Cambricon have dedicated convolution hardware, or is it re-using some sort of matrix multiplication units?

sethbrin commented 1 year ago

On the Cambricon architecture, there are some special convolution hardware instructions, which use the NHWC layout. The memory hierarchy is similar to NVIDIA's: it also has global memory and shared memory, except that it uses scratchpad memory instead of register files for temporary storage during computation.

For conv2d with the NHWC layout, we can represent it by pattern matching im2col (NHWC) + matmul. Is there a timeline for supporting tl.im2col (perhaps NHWC)?

brnorris03 commented 1 year ago

> There is actually no plan to support the linalg dialect, though there is a plan for Triton to be flexible enough to represent everything the linalg dialect can represent.

It's been a few months, so I wanted to double check -- there is still no plan to actually generate linalg directly, just make it possible? This (linalg) is something we would like to have as an integration point into our backend. If you are aware of others working on that, it would be good to know, too. Thanks!

Jokeren commented 1 year ago

I believe @sethbrin is working on it

ptillet commented 1 year ago

AFAIK, the Linalg dialect is -- in many respects -- higher level than Triton, as it doesn't have pointers and doesn't interface well with unstructured control flow. It is very possible that one could generate a linalg program from some kind of restricted Triton spec, but -- unless it can support arbitrary tensors of pointers and branches -- this is probably not something that we would consider upstreaming at the moment. I think it would be much more valuable to go directly from Triton to an even lower-level dialect closer to your ISA.

brnorris03 commented 1 year ago

Thanks, that makes sense.

sethbrin commented 1 year ago

Yes, the Linalg dialect is higher level than Triton. We are trying to do Triton-to-linalg conversion and are currently building a small POC. Indeed, as @ptillet said, arbitrary tensors of pointers and branches are important issues, so we may only be able to convert part of the Triton spec to linalg.

For very fancy expressions, it may be better to bypass linalg entirely, but for compute-bound or performance-critical operators such as conv/matmul/reduce, we still hope to convert to linalg and reuse some of the official infrastructure for progressive codegen.

We need to choose this way for two main reasons:

brnorris03 commented 1 year ago

@sethbrin We are in a very similar situation, that is why I asked. Will your POC be open sourced?

sethbrin commented 1 year ago

@brnorris03 Yes, we will submit an RFC when we complete the POC.

sethbrin commented 1 year ago

@ptillet @Jokeren @brnorris03 I have submitted an RFC for Triton-to-linalg conversion at https://github.com/openai/triton/discussions/1542.