triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

AOT Compiler: use grid as part of C api's args #3488


yinqiwen commented 6 months ago

Currently, users of the AOT compiler must supply the grid when generating the C API sources, and the generated code then launches the kernel with that grid hardcoded as constant values.

Would it be possible to expose the grid as arguments of the generated C API? That would make the generated code more dynamic: users could compute the grid at runtime and then invoke the generated function with it.

For reference, this is the currently generated launcher for a `silu_and_mul` kernel; note the fixed `gX`, `gY`, `gZ`:

```c
/*
['BLOCK_M=128', 'BLOCK_N=128', 'num_warps=1', 'num_stages=3']
*/
CUresult silu_and_mul_f3d435fd_0d123456(CUstream stream, CUdeviceptr input_ptr,
                                        int32_t stride_input_m, int32_t stride_input_n,
                                        int32_t stride_output_m, int32_t stride_output_n,
                                        int32_t size_m, int32_t size_n) {
    if (silu_and_mul_f3d435fd_0d123456_func == NULL)
        load_silu_and_mul_f3d435fd_0d123456();
    // grid is fixed when the sources are generated
    unsigned int gX = 32;
    unsigned int gY = 32;
    unsigned int gZ = 1;
    void *args[7] = { &input_ptr, &stride_input_m, &stride_input_n, &stride_output_m, &stride_output_n, &size_m, &size_n };
    // TODO: shared memory
    if (gX * gY * gZ > 0)
        return cuLaunchKernel(silu_and_mul_f3d435fd_0d123456_func, gX, gY, gZ, 1 * 32, 1, 1, 0, stream, args, NULL);
    return CUDA_SUCCESS; // empty grid: nothing to launch (otherwise control falls off the end without a return value)
}
```
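A minimal sketch of what the requested API shape could look like, assuming the grid is simply moved from local constants into the parameter list right after the stream. To keep it buildable without `cuda.h` or a GPU, everything here is hypothetical: `sketch_result` stands in for `CUresult`, `void *` for `CUstream`, `uint64_t` for `CUdeviceptr`, and `fake_launch` only records the grid where real generated code would call `cuLaunchKernel`. The name `silu_and_mul_launch` is illustrative, not something the AOT compiler emits.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Hypothetical stand-in for CUresult so the sketch builds without cuda.h. */
typedef enum { SKETCH_SUCCESS = 0 } sketch_result;

/* Hypothetical launch stub: records the grid instead of calling cuLaunchKernel. */
static unsigned int last_gX, last_gY, last_gZ;
static int launch_count = 0;

static sketch_result fake_launch(void *stream, unsigned int gX, unsigned int gY,
                                 unsigned int gZ, unsigned int block_x, void **args) {
    (void)stream;
    (void)block_x;
    assert(args != NULL); /* kernel arguments must always be passed */
    last_gX = gX;
    last_gY = gY;
    last_gZ = gZ;
    launch_count++;
    return SKETCH_SUCCESS;
}

/* Proposed shape (hypothetical name): the caller passes the grid at runtime. */
sketch_result silu_and_mul_launch(void *stream, /* stands in for CUstream */
                                  unsigned int gX, unsigned int gY, unsigned int gZ,
                                  uint64_t input_ptr, /* stands in for CUdeviceptr */
                                  int32_t stride_input_m, int32_t stride_input_n,
                                  int32_t stride_output_m, int32_t stride_output_n,
                                  int32_t size_m, int32_t size_n) {
    void *args[7] = { &input_ptr, &stride_input_m, &stride_input_n,
                      &stride_output_m, &stride_output_n, &size_m, &size_n };
    /* An empty grid launches nothing; return explicitly instead of falling off the end. */
    if (gX == 0 || gY == 0 || gZ == 0)
        return SKETCH_SUCCESS;
    /* num_warps=1 -> 1 * 32 threads per block, matching the generated code. */
    return fake_launch(stream, gX, gY, gZ, 1 * 32, args);
}
```

The only change from the generated launcher is where `gX`, `gY`, `gZ` come from; the argument packing and launch call stay the same, so the generator would just emit three more parameters.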
GetUpEarlier commented 1 month ago

The grid can be a simple C expression. I think we should separate the tunable algorithm arguments from the problem arguments, and allow the grid value to be computed from extra problem arguments.
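A rough sketch of that separation, assuming the kernel maps program IDs along axis 0 to `size_m` tiles and along axis 1 to `size_n` tiles (the issue does not show the kernel body, so this mapping is an assumption). `BLOCK_M` and `BLOCK_N` are tunable arguments fixed at compile time (values taken from the header comment of the generated code), while `size_m` and `size_n` are problem arguments known only at runtime. `silu_and_mul_grid` and `cdiv` are hypothetical helpers, not part of the AOT compiler.

```c
#include <assert.h>
#include <stdint.h>

/* Tunable (algorithm) arguments: baked in when the kernel is compiled. */
#define SILU_AND_MUL_BLOCK_M 128
#define SILU_AND_MUL_BLOCK_N 128

/* Ceiling division: number of tiles of size d needed to cover n elements.
   Written as n / d + remainder to avoid overflow in n + d - 1. */
static unsigned int cdiv(int32_t n, int32_t d) {
    assert(n >= 0 && d > 0);
    return (unsigned int)(n / d + (n % d != 0));
}

/* Problem arguments: known only at runtime; the grid is a plain C expression of them. */
static void silu_and_mul_grid(int32_t size_m, int32_t size_n,
                              unsigned int *gX, unsigned int *gY, unsigned int *gZ) {
    *gX = cdiv(size_m, SILU_AND_MUL_BLOCK_M);
    *gY = cdiv(size_n, SILU_AND_MUL_BLOCK_N);
    *gZ = 1;
}
```

For a 4096 x 4096 problem this yields the same 32 x 32 x 1 grid that is hardcoded in the generated code, but a 100 x 129 problem would get 1 x 2 x 1 without regenerating the sources.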