rise-lang / shine

The Shine compiler for the RISE language
https://rise-lang.org
MIT License
73 stars · 8 forks · source link

Fix SeparateHostAndKernelCode for NatToNatLambda #218

Closed — Bastacyclop closed this pull request 2 years ago.

Bastacyclop commented 2 years ago

Fixes #217

Generated code after fix for gemvFusedAMD:

/*
 * OpenCL C source of kernel k0 (gemvFusedAMD), embedded so that loadKernel
 * can compile it; loadKernel passes sizeof(k0_source) - 1, so the
 * terminating NUL is not counted as part of the source.
 *
 * What the kernel computes, read off the code below: for each row r < n216,
 *   output[r] = e220 * sum_k(e217[r * n215 + k] * e218[k]) + e221 * e219[r]
 * with k < 128 * (n215 / 128). The integer division means columns past the
 * last full block of 128 are not summed (presumably n215 is always a
 * multiple of 128 — confirm with the generator).
 *
 * Each work group handles rows wg_id, wg_id + 8, wg_id + 16, ..., so the
 * kernel assumes 8 work groups of 128 work items (reqd_work_group_size);
 * foo_run launches 1024 global / 128 local to match.
 * Local buffer x471 holds one partial sum per work item (128 floats); x453
 * holds the single per-group total, written only by work item 0.
 *
 * NOTE(review): the fragments below contain no "\n", so the kernel is
 * compiled as one long line. That is harmless while only block comments
 * are used, but a "//" comment would swallow the rest of the kernel —
 * confirm the newlines were not lost when this code was copied.
 */
const char k0_source[] =
""
""
""
"__kernel __attribute__ ((reqd_work_group_size(128, 1, 1)))"
"void k0(global float* restrict output, int n216, int n215, float e221, float e220, const global float* restrict e217, const global float* restrict e219, const global float* restrict e218, local float* restrict x471, local float* restrict x453){"
"  /* Start of moved local vars */"
"  /* End of moved local vars */"
"  /* mapWorkGroup */"
"  for (int wg_id_505 = get_group_id(0); wg_id_505 < n216; wg_id_505 = 8 + wg_id_505) {"
"    /* mapLocal */"
"    /* iteration count is exactly 1, no loop emitted */"
"    int l_id_506 = get_local_id(0);"
"    /* oclReduceSeq */"
"    {"
"      float x478;"
"      x478 = 0.0f;"
"      for (int i_507 = 0; i_507 < (n215 / 128); i_507 = 1 + i_507) {"
"        x478 = (e218[l_id_506 + (128 * i_507)] * e217[(l_id_506 + (128 * i_507)) + (n215 * wg_id_505)]) + x478;"
"      }"
"      "
"      x471[l_id_506] = x478;"
"    }"
"    "
"    barrier(CLK_LOCAL_MEM_FENCE);"
"    /* mapLocal */"
"    for (int l_id_508 = get_local_id(0); l_id_508 < 1; l_id_508 = 128 + l_id_508) {"
"      /* oclReduceSeq */"
"      {"
"        float x460;"
"        x460 = 0.0f;"
"        for (int i_509 = 0; i_509 < 128; i_509 = 1 + i_509) {"
"          x460 = x460 + x471[i_509 + (128 * l_id_508)];"
"        }"
"        "
"        x453[l_id_508] = x460;"
"      }"
"      "
"    }"
"    "
"    barrier(CLK_LOCAL_MEM_FENCE);"
"    /* mapLocal */"
"    for (int l_id_510 = get_local_id(0); l_id_510 < 1; l_id_510 = 128 + l_id_510) {"
"      output[l_id_510 + wg_id_505] = (e220 * x453[l_id_510]) + (e219[wg_id_505] * e221);"
"    }"
"    "
"    barrier(CLK_LOCAL_MEM_FENCE);"
"  }"
"  "
"}"
"";

/* Compile kernel `id` from the char array `id##_source`, using the kernel's
 * own name as its entry point; the length excludes the array's trailing NUL. */
#define loadKernel(ctx, id) \
  loadKernelFromSource((ctx), #id, id##_source, (sizeof id##_source) - 1)

#include "ocl/ocl.h"
/* Host-side state for the generated program: the single compiled kernel k0.
 * The struct tag and the typedef share the name foo_t. */
typedef struct foo_t {
  Kernel k0;
} foo_t;

/* Compile kernel k0 from its embedded source and store it in *self. */
void foo_init(Context ctx, foo_t *self)
{
  self->k0 = loadKernel(ctx, k0);
}

/* Hand the kernel stored by foo_init back to the runtime via destroyKernel. */
void foo_destroy(Context ctx, foo_t *self)
{
  destroyKernel(ctx, self->k0);
}

/*
 * Launch kernel k0 once. Per the kernel source, for each row r < n216:
 *   output[r] = e220 * sum_k(e217[r * n215 + k] * e218[k]) + e221 * e219[r]
 * where k < 128 * (n215 / 128) (presumably n215 is a multiple of 128 —
 * confirm with callers).
 *
 * ctx      runtime context, passed through to the ocl helpers.
 * self     state previously filled in by foo_init.
 * moutput  n216 floats, written.
 * me217    n216 * n215 floats, read.
 * me218    n215 floats, read.
 * me219    n216 floats, read.
 */
void foo_run(Context ctx, foo_t *self, Buffer moutput, int n215, int n216,
             Buffer me217, Buffer me218, Buffer me219, float e220, float e221)
{
  /* Must agree with k0_source: reqd_work_group_size(128, 1, 1) and a
   * work-group loop that strides by 8, i.e. exactly 8 groups of 128. */
  enum { K0_LOCAL_SIZE = 128, K0_NUM_GROUPS = 8, K0_NUM_ARGS = 10 };

  DeviceBuffer b0 = deviceBufferSync(ctx, moutput, n216 * sizeof(float), DEVICE_WRITE);
  DeviceBuffer b5 = deviceBufferSync(ctx, me217, n216 * (n215 * sizeof(float)), DEVICE_READ);
  DeviceBuffer b6 = deviceBufferSync(ctx, me219, n216 * sizeof(float), DEVICE_READ);
  DeviceBuffer b7 = deviceBufferSync(ctx, me218, n215 * sizeof(float), DEVICE_READ);

  /* Plain brace initializers: initializing an array from a compound
   * literal, as before, is a GNU extension rather than ISO C. */
  const size_t global_size[3] = {K0_LOCAL_SIZE * K0_NUM_GROUPS, 1, 1};
  const size_t local_size[3] = {K0_LOCAL_SIZE, 1, 1};

  /* Order must match k0's parameter list. */
  const KernelArg args[K0_NUM_ARGS] = {
    KARG(b0), KARG(n216), KARG(n215), KARG(e221), KARG(e220),
    KARG(b5), KARG(b6), KARG(b7),
    LARG(K0_LOCAL_SIZE * sizeof(float)), /* x471: one partial sum per work item */
    LARG(1 * sizeof(float)),             /* x453: the work group's total */
  };
  launchKernel(ctx, self->k0, global_size, local_size, K0_NUM_ARGS, args);
}

/* One-shot convenience wrapper: foo_init, a single foo_run, then foo_destroy,
 * all on a stack-local foo_t. */
void foo_init_run(Context ctx, Buffer moutput, int n215, int n216,
                  Buffer me217, Buffer me218, Buffer me219, float e220, float e221)
{
  foo_t instance;
  foo_t *const handle = &instance;

  foo_init(ctx, handle);
  foo_run(ctx, handle, moutput, n215, n216, me217, me218, me219, e220, e221);
  foo_destroy(ctx, handle);
}
johanneslenfers commented 2 years ago

I added host-code generation checks for the other gemv versions to make it more consistent.