chsasank / llama.lisp

Lisp dialect designed for HPC and AI
GNU Lesser General Public License v2.1

matmul implementations in c-lisp #62

Closed. adi-lb-phoenix closed this pull request 2 months ago.

adi-lb-phoenix commented 3 months ago

…ts in matmul.out, and a turnt.toml file, which contains the command to execute the .sexp files in this folder. There are 3 files here: matmul.sexp, matmul.out and turnt.toml. matmul.sexp contains the implementation of the matmul operation from llama2.c, which is a matrix multiplication. matmul.out contains the expected output used to test matmul.sexp. The multiplication is done with two matrices of size 10×10 and 10×1, so the final output contains 10 numbers.

Below is the C file used to generate the outputs in matmul.out. The same logic is implemented in matmul.sexp so that we can verify our implementation: matmul_testOUTgenerator.txt

GlowingScrewdriver commented 2 months ago

In your initialize function:

   (define ((initialise void) (xout (ptr float)) (x (ptr float)) (w (ptr float)) (column int) (row int))

Why is xout a parameter? I don't see it being used anywhere in the function body.

adi-lb-phoenix commented 2 months ago

Yes, you are right about (xout (ptr float)). It is not used anywhere in the initialise body. I will remove it.

adi-lb-phoenix commented 2 months ago

The code below is the C version of matmul.sexp.

#include <stdio.h>
#include <stdlib.h>

void matmul(float* xout, float* x, float* w, int n, int d) {
    // W (d,n) @ x (n,) -> xout (d,)
    // by far the most amount of time is spent inside this little function
    for (int i = 0; i < d; i++) {
        float val = 0.0f;
        for (int j = 0; j < n; j++) {
            val += w[i * n + j] * x[j];
        }
        xout[i] = val;
        printf("%f\n", xout[i]); // one line per output element, as recorded in matmul.out
    }
}

// Fill x (length column) with 0..column-1 and the row-major
// matrix w (row x column) with w[i][j] = i + j.
// xout is no longer a parameter because it was never used here.
void initialise(float *x, float *w, int column, int row) {
    for (int j = 0; j < column; j++) {
        x[j] = j;
    }
    for (int i = 0; i < row; i++) {
        for (int j = 0; j < column; j++) {
            w[i * column + j] = i + j; // stride is the row length (column)
        }
    }
}

int main(void) {
    int n = 10;
    int d = 10;
    float *xout = calloc(d, sizeof(float));
    float *w = calloc(d * n, sizeof(float));
    float *x = calloc(n, sizeof(float));
    if (xout == NULL || w == NULL || x == NULL) {
        fprintf(stderr, "allocation failed\n");
        return 1;
    }
    initialise(x, w, n, d);
    matmul(xout, x, w, n, d);
    free(xout);
    free(w);
    free(x);
    return 0;
}