wellposed / hblas

Haskell bindings for BLAS and LAPACK
www.wellposed.com
BSD 3-Clause "New" or "Revised" License

Expanding into support for gemv #13

Closed: mstksg closed this issue 10 years ago

mstksg commented 10 years ago

So I'm looking into tweaking the current gemmAbstraction from BLAS.hs and adapting it for gemv. Any pointers on which aspects are likely to need changing?

cartazio commented 10 years ago

Good question. (Note: I've not added proper BLAS vectors yet, so let's just assume we're given Data.Vector.Storable.Mutable vectors of stride one.)

First, let's work through how gemmAbstraction works! It lives here: https://github.com/wellposed/hblas/blob/master/src/Numerical/HBLAS/BLAS.hs#L76-L107


flopsThreshold = 10000
gemmComplexity a b c = a * b * c  -- this will be wrong by some constant factor, albeit a small one

-- this covers the ~6 cases for checking the dimensions for GEMM quite nicely
isBadGemm tra trb  ax ay bx by cx cy = isBadGemmHelper (cds tra (ax,ay)) (cds trb (bx,by) )  (cx,cy)
    where 
    cds = coordSwapper 
    isBadGemmHelper !(ax,ay) !(bx,by) !(cx,cy) =  (minimum [ax, ay, bx, by, cx ,cy] <= 0)  
        || not (  cy ==  ay && cx == bx && ax == by)

coordSwapper :: Transpose -> (a,a)-> (a,a)
coordSwapper NoTranspose (a,b) = (a,b)
coordSwapper ConjNoTranspose (a,b) = (a,b) 
coordSwapper Transpose (a,b) = (b,a)
coordSwapper ConjTranspose (a,b) = (b,a)

encodeNiceOrder :: SOrientation x  -> CBLAS_ORDERT
encodeNiceOrder SRow= encodeOrder  BLASRowMajor
encodeNiceOrder SColumn= encodeOrder BLASColMajor

encodeFFITranpose :: Transpose -> CBLAS_TRANSPOSET
encodeFFITranpose  x=  encodeTranpose $ encodeNiceTranpose x 

encodeNiceTranpose :: Transpose -> BLAS_Transpose
encodeNiceTranpose x = case x of 
        NoTranspose -> BlasNoTranspose
        Transpose -> BlasTranspose
        ConjTranspose -> BlasConjTranspose
        ConjNoTranspose -> BlasConjNoTranspose

--data BLAS_Tranpose = BlasNoTranspose | BlasTranpose | BlasConjTranspose | BlasConjNoTranpose 
--data Tranpose = NoTranpose | Tranpose | ConjTranpose | ConjNoTranpose

type GemmFun el orient s m = Transpose ->Transpose ->  el -> el  -> MutDenseMatrix s orient el
  ->   MutDenseMatrix s orient el  ->  MutDenseMatrix s orient el -> m ()

{-
A key design goal of this FFI is to provide *safe* throughput guarantees
for a concurrent application built on top of these APIs, while avoiding
any overhead for providing such safety. Accordingly, on input sizes at or
below flopsThreshold we use the unsafe FFI call, and above it the safe one.
-}

---- |  Matrix mult for general dense matrices
--type GemmFunFFI scale el = CBLAS_ORDERT ->   CBLAS_TRANSPOSET -> CBLAS_TRANSPOSET->
        --CInt -> CInt -> CInt -> {- scal A * B -} scale  -> {- Matrix A-} Ptr el  -> CInt -> {- B -}  Ptr el -> CInt-> 
            --scale -> {- C -}  Ptr el -> CInt -> IO ()
--type GemmFun = MutDenseMatrix or el ->  MutDenseMatrix or el ->   MutDenseMatrix or el -> m ()

{-# NOINLINE gemmAbstraction #-}
gemmAbstraction:: (SM.Storable el, PrimMonad m) =>  String -> 
    GemmFunFFI scale el -> GemmFunFFI scale el -> (el -> (scale -> m ())->m ()) -> forall orient . GemmFun el orient (PrimState m) m 
gemmAbstraction gemmName gemmSafeFFI gemmUnsafeFFI constHandler = go 
  where 
    shouldCallFast :: Int -> Int -> Int -> Bool                         
    shouldCallFast cy cx ax = flopsThreshold >= gemmComplexity cy cx ax

    go  tra trb  alpha beta 
        (MutableDenseMatrix ornta ax ay astride abuff) 
        (MutableDenseMatrix _ bx by bstride bbuff) 
        (MutableDenseMatrix _ cx cy cstride cbuff) 
            |  isBadGemm tra trb  ax ay bx by cx cy = error $! "bad dimension args to GEMM: ax ay bx by cx cy: " ++ show [ax, ay, bx, by, cx ,cy]
            | SM.overlaps abuff cbuff || SM.overlaps bbuff cbuff = 
                    error $ "the read and write inputs for: " ++ gemmName ++ " overlap. This is a programmer error. Please fix." 
            | otherwise  = 
                {-  FIXME : Add Sharing check that also errors out for now-}
                unsafeWithPrim abuff $ \ap -> 
                unsafeWithPrim bbuff $ \bp ->  
                unsafeWithPrim cbuff $ \cp  -> 
                constHandler alpha $  \alphaPtr ->   
                constHandler beta $ \betaPtr ->  
                    do  (ax,ay) <- return $ coordSwapper tra (ax,ay)
                        --- don't need to swap b, the info is in a and c
                        --- c doesn't get implicitly transposed
                        blasOrder <- return $ encodeNiceOrder ornta -- all three are the same orientation
                        rawTra <- return $  encodeFFITranpose tra 
                        rawTrb <- return $   encodeFFITranpose trb
                                 -- example of why I want to switch to singletons
                        unsafePrimToPrim $!  (if shouldCallFast cy cx ax then gemmUnsafeFFI  else gemmSafeFFI ) 
                            blasOrder rawTra rawTrb (fromIntegral cy) (fromIntegral cx) (fromIntegral ax) 
                                alphaPtr ap  (fromIntegral astride) bp (fromIntegral bstride) betaPtr  cp (fromIntegral cstride)
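For gemv, the dimension check is the main piece of the above that needs rethinking. Below is a minimal sketch, assuming stride-one vectors and the convention that x counts columns and y counts rows (inferred from the FFI call passing cy, cx, ax as cblas_dgemm's M, N, K). isBadGemv and gemvMN are hypothetical names, and Transpose/coordSwapper are redefined locally so the snippet stands alone. One difference worth flagging: in the CBLAS interface, cblas_?gemv's M and N are the rows and columns of A itself, not of op(A), so the swap should feed only the check, not the FFI arguments.

```haskell
-- Local stand-in for hblas's Transpose (same constructors as MatrixTypes.hs).
data Transpose = NoTranspose | Transpose | ConjTranspose | ConjNoTranspose
  deriving (Eq, Show)

-- Mirrors hblas's coordSwapper: swap (x, y) for the transposing ops.
coordSwapper :: Transpose -> (a, a) -> (a, a)
coordSwapper NoTranspose     p      = p
coordSwapper ConjNoTranspose p      = p
coordSwapper Transpose       (a, b) = (b, a)
coordSwapper ConjTranspose   (a, b) = (b, a)

-- Hypothetical gemv analogue of isBadGemm for y := alpha*op(A)*x + beta*y.
-- ax/ay are A's column and row counts (the x/y convention isBadGemm uses);
-- xlen/ylen are the lengths of the stride-one vectors x and y.
isBadGemv :: Transpose -> Int -> Int -> Int -> Int -> Bool
isBadGemv tra ax ay xlen ylen =
  minimum [ax, ay, xlen, ylen] <= 0
    || xlen /= opX -- x needs one entry per column of op(A)
    || ylen /= opY -- y needs one entry per row of op(A)
  where
    (opX, opY) = coordSwapper tra (ax, ay)

-- Hypothetical: the (M, N) pair cblas_?gemv expects. Unlike gemm's M/N/K,
-- these are the rows and columns of A itself, so no swap is applied here.
gemvMN :: Int -> Int -> (Int, Int)
gemvMN ax ay = (ay, ax)
```

For a 3x2 A (ax = 2, ay = 3), NoTranspose wants xlen = 2 and ylen = 3, Transpose wants the reverse, and in both cases the FFI still receives M = 3, N = 2.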
cartazio commented 10 years ago

So, what do you want to understand about the above? Ask!

cartazio commented 10 years ago

Notice that the resulting function, go, has the type

forall orient . GemmFun el orient (PrimState m) m 
cartazio commented 10 years ago

The wrapped-up dgemm is defined like so:

dgemm :: PrimMonad m=> 
     Transpose ->Transpose ->  Double -> Double -> MutDenseMatrix (PrimState m) orient Double  ->   MutDenseMatrix (PrimState m) orient Double   ->  MutDenseMatrix (PrimState m) orient Double -> m ()
-- gemmAbstraction takes the safe FFI import first, then the unsafe one
dgemm = gemmAbstraction "dgemm" cblas_dgemm_safe cblas_dgemm_unsafe (\x f -> f x)
cartazio commented 10 years ago

The last function argument,

(\x f -> f x)

exists because in the complex case we need to actually store the complex number in memory (a struct) and pass a pointer to it along to the FFI'd code.
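To make the two kinds of constHandler concrete, here is a minimal sketch specialised to IO (gemmAbstraction itself works in any PrimMonad). realConstHandler is the (\x f -> f x) used for dgemm; complexConstHandler and roundTrip are hypothetical names, and using Foreign.Marshal.Utils.with to get a temporary pointer is an assumption, not necessarily what hblas does.

```haskell
import Data.Complex (Complex (..))
import Data.IORef (newIORef, readIORef, writeIORef)
import Foreign.Marshal.Utils (with)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, peek)

-- Real case: cblas_sgemm/cblas_dgemm take alpha and beta by value,
-- so the handler just hands the scalar straight to the continuation.
realConstHandler :: el -> (el -> IO ()) -> IO ()
realConstHandler x f = f x

-- Complex case (hypothetical name): cblas_cgemm/cblas_zgemm take alpha and
-- beta by pointer, so copy the scalar into temporary memory that stays
-- valid only for the duration of the continuation.
complexConstHandler :: Storable a => Complex a -> (Ptr (Complex a) -> IO ()) -> IO ()
complexConstHandler = with

-- Usage: push a complex scalar through the handler and read it back.
roundTrip :: Complex Double -> IO (Complex Double)
roundTrip z = do
  ref <- newIORef 0
  complexConstHandler z (\p -> peek p >>= writeIORef ref)
  readIORef ref
```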

The Transpose type is defined in https://github.com/wellposed/hblas/blob/master/src/Numerical/HBLAS/MatrixTypes.hs#L100

data Transpose = NoTranspose | Transpose | ConjTranspose | ConjNoTranspose
    deriving(Typeable,Eq,Show)
cartazio commented 10 years ago

It's worth re-emphasizing the type of gemmAbstraction:

gemmAbstraction:: (SM.Storable el, PrimMonad m) =>  String -> 
    GemmFunFFI scale el -> GemmFunFFI scale el -> (el -> (scale -> m ())->m ()) -> forall orient . GemmFun el orient (PrimState m) m 
cartazio commented 10 years ago

GemmFunFFI is defined at https://github.com/wellposed/hblas/blob/master/src/Numerical/HBLAS/BLAS/FFI.hs#L440

-- |  Matrix mult for general dense matrices
type GemmFunFFI scale el = CBLAS_ORDERT ->   CBLAS_TRANSPOSET -> CBLAS_TRANSPOSET->
        CInt -> CInt -> CInt -> {- scal A * B -} scale  -> {- Matrix A-} Ptr el  -> CInt -> {- B -}  Ptr el -> CInt-> 
            scale -> {- C -}  Ptr el -> CInt -> IO ()

{- C := alpha*op( A )*op( B ) + beta*C ,  -}

-- matrix mult!
foreign import ccall unsafe "cblas_sgemm" 
    cblas_sgemm_unsafe :: GemmFunFFI Float Float

foreign import ccall unsafe "cblas_dgemm" 
    cblas_dgemm_unsafe :: GemmFunFFI Double Double

foreign import ccall unsafe "cblas_cgemm" 
    cblas_cgemm_unsafe :: GemmFunFFI (Ptr(Complex Float)) (Complex Float)

foreign import ccall unsafe "cblas_zgemm" 
    cblas_zgemm_unsafe :: GemmFunFFI (Ptr (Complex Double)) (Complex Double)

-- safe ffi variant for large inputs
foreign import ccall "cblas_sgemm" 
    cblas_sgemm_safe :: GemmFunFFI Float Float

foreign import ccall "cblas_dgemm" 
    cblas_dgemm_safe :: GemmFunFFI Double Double

foreign import ccall "cblas_cgemm" 
    cblas_cgemm_safe :: GemmFunFFI (Ptr(Complex Float)) (Complex Float)

foreign import ccall "cblas_zgemm" 
    cblas_zgemm_safe :: GemmFunFFI (Ptr (Complex Double)) (Complex Double)
cartazio commented 10 years ago

To quote the MKL BLAS docs:

The ?gemm routines perform a matrix-matrix operation with general matrices. The operation is defined as

C := alpha*op(A)*op(B) + beta*C,
where:

op(x) is one of op(x) = x, or op(x) = x', or op(x) = conjg(x'),

The choice of op(x) is determined by the Transpose type:

data Transpose = NoTranspose | Transpose | ConjTranspose | ConjNoTranspose
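When testing a binding like dgemm, it helps to have a slow but obviously correct oracle for the formula above. A minimal sketch, assuming real-valued matrices stored as Haskell lists of rows rather than hblas's MutDenseMatrix; refGemm and op are hypothetical helpers written for this illustration.

```haskell
import Data.List (transpose)

-- Local stand-in for hblas's Transpose type.
data Transpose = NoTranspose | Transpose | ConjTranspose | ConjNoTranspose
  deriving (Eq, Show)

-- op(x) for real matrices stored as lists of rows. Conjugation is the
-- identity on reals, so the Conj variants collapse to their plain forms.
op :: Transpose -> [[Double]] -> [[Double]]
op NoTranspose     m = m
op ConjNoTranspose m = m
op Transpose       m = transpose m
op ConjTranspose   m = transpose m

-- Naive reference for C := alpha*op(A)*op(B) + beta*C (real case only).
-- Useful as a test oracle for a binding; it does no dimension checking.
refGemm :: Transpose -> Transpose -> Double -> Double
        -> [[Double]] -> [[Double]] -> [[Double]] -> [[Double]]
refGemm tra trb alpha beta a b c =
  [ [ alpha * sum (zipWith (*) row col) + beta * cij
    | (col, cij) <- zip (transpose (op trb b)) crow
    ]
  | (row, crow) <- zip (op tra a) c
  ]
```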
cartazio commented 10 years ago

Closing this for now, since it kinda documents most of the things that someone who's familiar with the Haskell FFI will need help with.

cartazio commented 10 years ago

A key snippet of gemmAbstraction for understanding the use of both the unsafe and safe FFI calls (reproduced from above):

         unsafePrimToPrim $!  (if shouldCallFast cy cx ax then gemmUnsafeFFI  else gemmSafeFFI )

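The idea being: when the computation is very fast (at or below flopsThreshold, roughly 1-10µs of work), use the unsafe FFI call; above that complexity threshold, the extra overhead of a safe FFI call is small relative to the work, so we should use the safe call instead. (An unsafe call is cheaper, but until it returns it blocks other Haskell threads on that capability, along with any pending garbage collection, which is why it is reserved for short calls.)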
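The same dispatch carries over to gemv; only the complexity estimate changes, since gemv does O(mn) work instead of O(mnk). A minimal sketch reusing the flopsThreshold value and gemmComplexity from BLAS.hs; gemvComplexity and the two shouldCallFast* names are hypothetical.

```haskell
-- Same threshold as BLAS.hs.
flopsThreshold :: Int
flopsThreshold = 10000

-- As in BLAS.hs: a rough operation count, off by a small constant factor.
gemmComplexity :: Int -> Int -> Int -> Int
gemmComplexity m n k = m * n * k

-- Hypothetical gemv counterpart: one multiply-add per entry of A.
gemvComplexity :: Int -> Int -> Int
gemvComplexity m n = m * n

-- True means "cheap enough for the unsafe FFI call".
shouldCallFastGemm :: Int -> Int -> Int -> Bool
shouldCallFastGemm m n k = flopsThreshold >= gemmComplexity m n k

shouldCallFastGemv :: Int -> Int -> Bool
shouldCallFastGemv m n = flopsThreshold >= gemvComplexity m n
```

With this threshold a square gemv stays on the unsafe path up to 100x100, while gemm already moves to the safe path at 22x22x22 (10648 > 10000).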

mstksg commented 10 years ago

What high-level type do we want for sgemv? Something like this, perhaps:

sgemv :: PrimMonad m
    => Transpose -> Float -> Float
    -> MutDenseMatrix (PrimState m) orient Float     -- A
    -> MVector (PrimState m) Float          -- x
    -> MVector (PrimState m) Float          -- y
    -> MVector (PrimState m) Float          -- result
    -> m ()

?

cartazio commented 10 years ago

Yup, that's a good start. Just implement that and dgemv, and ignore complex for now while you're learning.

cartazio commented 10 years ago

Whoops, it should be:


sgemv :: PrimMonad m
    => Transpose -> Float -> Float
    -> MutDenseMatrix (PrimState m) orient Float
    -> MVector (PrimState m) Float
    -> MVector (PrimState m) Float
    -> m ()
mstksg commented 10 years ago

Huh, I realized that y is both an input vector and the result vector, so there should only be two vectors, yeah.
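Since y is both read and overwritten, a pure oracle is handy for checking an sgemv/dgemv binding against known answers. A minimal sketch, assuming real matrices held as lists of rows rather than MutDenseMatrix, and stride-one vectors; refGemv is a hypothetical helper.

```haskell
import Data.List (transpose)

-- Local stand-in for hblas's Transpose type.
data Transpose = NoTranspose | Transpose | ConjTranspose | ConjNoTranspose
  deriving (Eq, Show)

-- Naive reference for y := alpha*op(A)*x + beta*y (real case, stride one).
-- A is given as a list of rows. The BLAS routine overwrites y in place;
-- this pure version returns the new y instead.
refGemv :: Transpose -> Double -> Double -> [[Double]] -> [Double] -> [Double] -> [Double]
refGemv tra alpha beta a x y =
  [ alpha * sum (zipWith (*) row x) + beta * yi | (row, yi) <- zip opA y ]
  where
    -- Conjugation is the identity on real numbers.
    opA = case tra of
      NoTranspose     -> a
      ConjNoTranspose -> a
      Transpose       -> transpose a
      ConjTranspose   -> transpose a
```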

cartazio commented 10 years ago

For those following this thread at home, a good BLAS/LAPACK reference can be found here: http://software.intel.com/sites/products/documentation/hpc/mkl/mklman/index.htm

with the docs for sgemv in CBLAS here: http://software.intel.com/sites/products/documentation/hpc/mkl/mklman/GUID-25178576-05F1-4A33-8A0E-3694F0CCD242.htm