HazyResearch / ThunderKittens

Tile primitives for speedy kernels
MIT License

Template error #43

Open Hprairie opened 3 months ago

Hprairie commented 3 months ago

Hey, really awesome repo! I have been trying to set up a ThunderKittens environment, but clangd reports errors as soon as I include kittens. After some digging, I traced them to lines 52 and 55 of /ops/warp/shared/tile/conversions.cuh, where clangd reports: Use 'template' keyword to treat 'subtile' as a dependent template name. I am guessing this is an issue with my clang setup and not with ThunderKittens, but could you share the .clangd file used with the project so that I can fix it?

Thanks so much!
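For context, here is a minimal sketch of why Clang asks for the `template` keyword. The `mock_tile` type and `subtile_area` function are hypothetical stand-ins, not ThunderKittens code; they only assume, as the diagnostic suggests, that the tile type exposes a member template named `subtile`.

```cpp
#include <cassert>

// Hypothetical stand-in for a shared tile type. Like the type in the issue,
// it exposes a member alias template named `subtile`.
template<int H, int W>
struct mock_tile {
    static constexpr int height = H;
    static constexpr int width = W;

    // Concrete type that the alias below refers to.
    template<int SH, int SW>
    struct subtile_t {
        static constexpr int height = SH;
        static constexpr int width = SW;
    };

    template<int SH, int SW>
    using subtile = subtile_t<SH, SW>;
};

// ST is a dependent type here, so the compiler cannot assume that ST::subtile
// names a template. Without the `template` keyword, `ST::subtile<1, 2>` would
// be parsed with `<` as less-than, which is what the clangd diagnostic is about.
template<typename ST>
constexpr int subtile_area() {
    using sub = typename ST::template subtile<1, 2>;
    return sub::height * sub::width;
}
```

Outside a template, for example `mock_tile<4, 4>::subtile<2, 2>`, no keyword is needed because the type is not dependent.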

Hprairie commented 3 months ago

I have also edited the code in place as follows. This fixes the clangd error, but I will have to check whether it breaks compilation.

/* ----------  SUBTILE  ---------- */

/**
* @brief Returns a reference to a subtile of the given shared tile.
*
* @tparam subtile_height The height of the subtile.
* @tparam subtile_width The width of the subtile.
* @tparam ST The type of the input tile, which must satisfy the ducks::st::all concept.
* @param src The input tile.
* @param row_idx The row index of the subtile, in units of subtile_height*16 elements.
* @param col_idx The col index of the subtile, in units of subtile_width*16 elements.
* @return A reference to the subtile.
*
* @note The subtile {height, width} must evenly divide the tile {height, width}.
*/
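// ST is a dependent type, so `template` is required for ST::subtile to parse as a template name.
template<int subtile_height, int subtile_width, ducks::st::all ST>
__device__ inline typename ST::template subtile<subtile_height, subtile_width> subtile_inplace(ST &src, int row_idx, int col_idx) {
    static_assert(ST::height % subtile_height == 0, "subtile height must evenly divide tile height");
    static_assert(ST::width % subtile_width == 0, "subtile width must evenly divide tile width");
    // Offsets are in elements: each unit of height/width is 16 elements.
    return typename ST::template subtile<subtile_height, subtile_width>(
        &src[0], subtile_height*16*row_idx, subtile_width*16*col_idx
    );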
}
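The index arithmetic in the patched function can be checked on the host. This is a hypothetical model, not ThunderKittens code: `subtile_origin` and `subtile_offset` are made-up names that reproduce only the two offset expressions and the divisibility checks, with heights and widths counted in units of 16 elements as the doc comment describes.

```cpp
#include <cassert>

// Hypothetical model of the element offsets subtile_inplace passes to the
// subtile constructor. Offsets are measured in elements from the tile origin.
struct subtile_offset {
    int row;  // first element row of the subtile inside the parent tile
    int col;  // first element column of the subtile inside the parent tile
};

template<int tile_height, int tile_width, int subtile_height, int subtile_width>
constexpr subtile_offset subtile_origin(int row_idx, int col_idx) {
    // Same divisibility requirement as the static_asserts in the patch.
    static_assert(tile_height % subtile_height == 0, "subtile height must divide tile height");
    static_assert(tile_width % subtile_width == 0, "subtile width must divide tile width");
    // row_idx/col_idx are in units of subtile_height*16 / subtile_width*16 elements.
    return subtile_offset{subtile_height * 16 * row_idx, subtile_width * 16 * col_idx};
}
```

For example, splitting a 4x4 tile (64x64 elements) into 2x2 subtiles, index (1, 1) starts at element row 32, column 32.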