explosion / cython-blis

💥 Fast matrix-multiplication as a self-contained Python library – no system dependencies!

Where are the docs? #103

Closed: gabrielfougeron closed this issue 7 months ago

gabrielfougeron commented 7 months ago

Hi,

I'm trying to replicate the example in the README and wrap gemm in my own Cython code. Here is what I came up with for the matrix multiplication c = a @ b:

cimport blis.cy

cpdef Cython_blis(
    double[:,::1] a,
    double[:,::1] b,
    double[:,::1] c
) :

    cdef int p = a.shape[0]
    cdef int q = b.shape[1]
    cdef int r = a.shape[1]

    blis.cy.gemm(blis.cy.NO_TRANSPOSE, blis.cy.NO_TRANSPOSE,
                p, q, r,
                1.0, a, q, 1,
                b, p, 1,
                0.0, c, p, 1
            )

Upon building, Cython gives me this error:

Error compiling Cython file:                                            
------------------------------------------------------------
...                                 
cdef int p = a.shape[0]
cdef int q = b.shape[1]
cdef int r = a.shape[1]
blis.cy.gemm(blis.cy.NO_TRANSPOSE, blis.cy.NO_TRANSPOSE,                                                                                                                 
                  ^                                                                                                                                              
------------------------------------------------------------   
no suitable method found   

I'm guessing I need to point Cython to the correct include files, but I could not find any blis.get_include() or similar function.

Where can I find a minimal working example of a project that uses cython-blis, so I can better understand how to include it in my own project?

gabrielfougeron commented 7 months ago

I think I found the source of my problem: I needed to pass the address of the memoryview's first element (&a[0,0]), not the memoryview itself:

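cimport blis.cy

cpdef Cython_blis(
    double[:,::1] a,
    double[:,::1] b,
    double[:,::1] c
):
    # Computes c = a @ b with a: (p, r), b: (r, q), c: (p, q)
    cdef int p = a.shape[0]
    cdef int q = b.shape[1]
    cdef int r = a.shape[1]

    # gemm does no bounds checking, so reject mismatched shapes up front
    if b.shape[0] != r or c.shape[0] != p or c.shape[1] != q:
        raise ValueError("shape mismatch: expected a (p, r), b (r, q), c (p, q)")

    # gemm takes a raw pointer (&x[0,0]) plus a row stride and a column
    # stride for each matrix. For a C-contiguous array the row stride is
    # its number of columns and the column stride is 1.
    blis.cy.gemm(blis.cy.NO_TRANSPOSE, blis.cy.NO_TRANSPOSE,
                 p, q, r,
                 1.0, &a[0,0], r, 1,
                 &b[0,0], q, 1,
                 0.0, &c[0,0], q, 1
             )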
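The stride arguments are easy to get wrong when the matrices are not square. Each matrix is described by a pointer plus a row stride and a column stride, counted in elements rather than bytes. For a C-contiguous double[:,::1] array, the row stride equals the number of columns (r for a, q for b and c) and the column stride is 1. The NumPy-only sketch below does not call cython-blis. It prints those strides for the shapes used above and gives a plain NumPy reference for c = alpha * a @ b + beta * c, which is the standard BLAS gemm definition, so you can check the wrapper's output. row_col_strides and reference_gemm are hypothetical helper names.

```python
import numpy as np


def row_col_strides(m):
    """Return (row stride, column stride) of `m` counted in elements, not bytes."""
    rs, cs = m.strides
    return rs // m.itemsize, cs // m.itemsize


def reference_gemm(alpha, a, b, beta, c):
    """Plain NumPy version of the BLAS gemm update: alpha * a @ b + beta * c."""
    return alpha * (a @ b) + beta * c


p, r, q = 2, 3, 4  # a is p x r, b is r x q, c is p x q (non-square on purpose)
a = np.arange(p * r, dtype=np.float64).reshape(p, r)
b = np.arange(r * q, dtype=np.float64).reshape(r, q)
c = np.zeros((p, q), dtype=np.float64)

print(row_col_strides(a))  # (3, 1), i.e. (r, 1)
print(row_col_strides(b))  # (4, 1), i.e. (q, 1)
print(row_col_strides(c))  # (4, 1), i.e. (q, 1)

# With alpha = 1.0 and beta = 0.0, gemm should reproduce a @ b exactly
expected = reference_gemm(1.0, a, b, 0.0, c)
```

After calling the compiled Cython_blis(a, b, c), np.allclose(c, expected) should hold. If it does not, check that each matrix's first stride matches its number of columns.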