Closed mstksg closed 10 years ago
good question (note i've not added proper blas vectors yet, so let's just assume we're given Data.Vector.Storable.Mutable vectors of stride one)
first let's work through how gemmAbstraction works! it lives over here https://github.com/wellposed/hblas/blob/master/src/Numerical/HBLAS/BLAS.hs#L76-L107
flopsThreshold = 10000
gemmComplexity a b c = a * b * c -- this will be wrong by some constant factor, albeit a small one
-- this covers the ~6 cases for checking the dimensions for GEMM quite nicely
isBadGemm tra trb ax ay bx by cx cy = isBadGemmHelper (cds tra (ax,ay)) (cds trb (bx,by)) (cx,cy)
  where
    cds = coordSwapper
    isBadGemmHelper !(ax,ay) !(bx,by) !(cx,cy) = (minimum [ax, ay, bx, by, cx, cy] <= 0)
      || not (cy == ay && cx == bx && ax == by)
coordSwapper :: Transpose -> (a,a)-> (a,a)
coordSwapper NoTranspose (a,b) = (a,b)
coordSwapper ConjNoTranspose (a,b) = (a,b)
coordSwapper Transpose (a,b) = (b,a)
coordSwapper ConjTranspose (a,b) = (b,a)
encodeNiceOrder :: SOrientation x -> CBLAS_ORDERT
encodeNiceOrder SRow= encodeOrder BLASRowMajor
encodeNiceOrder SColumn= encodeOrder BLASColMajor
encodeFFITranpose :: Transpose -> CBLAS_TRANSPOSET
encodeFFITranpose x= encodeTranpose $ encodeNiceTranpose x
encodeNiceTranpose :: Transpose -> BLAS_Transpose
encodeNiceTranpose x = case x of
    NoTranspose -> BlasNoTranspose
    Transpose -> BlasTranspose
    ConjTranspose -> BlasConjTranspose
    ConjNoTranspose -> BlasConjNoTranspose
--data BLAS_Tranpose = BlasNoTranspose | BlasTranpose | BlasConjTranspose | BlasConjNoTranpose
--data Tranpose = NoTranpose | Tranpose | ConjTranpose | ConjNoTranpose
type GemmFun el orient s m = Transpose -> Transpose -> el -> el -> MutDenseMatrix s orient el
    -> MutDenseMatrix s orient el -> MutDenseMatrix s orient el -> m ()
{-
A key design goal of this ffi is to provide *safe* throughput guarantees
for a concurrent application built on top of these apis, while evading
any overheads for providing such safety. Accordingly, on inputs sizes
-}
---- | Matrix mult for general dense matrices
--type GemmFunFFI scale el = CBLAS_ORDERT -> CBLAS_TRANSPOSET -> CBLAS_TRANSPOSET->
--CInt -> CInt -> CInt -> {- scal A * B -} scale -> {- Matrix A-} Ptr el -> CInt -> {- B -} Ptr el -> CInt->
--scale -> {- C -} Ptr el -> CInt -> IO ()
--type GemmFun = MutDenseMatrix or el -> MutDenseMatrix or el -> MutDenseMatrix or el -> m ()
{-# NOINLINE gemmAbstraction #-}
gemmAbstraction :: (SM.Storable el, PrimMonad m) => String ->
    GemmFunFFI scale el -> GemmFunFFI scale el -> (el -> (scale -> m ()) -> m ()) -> forall orient . GemmFun el orient (PrimState m) m
gemmAbstraction gemmName gemmSafeFFI gemmUnsafeFFI constHandler = go
  where
    shouldCallFast :: Int -> Int -> Int -> Bool
    shouldCallFast cy cx ax = flopsThreshold >= gemmComplexity cy cx ax
    go tra trb alpha beta
        (MutableDenseMatrix ornta ax ay astride abuff)
        (MutableDenseMatrix _ bx by bstride bbuff)
        (MutableDenseMatrix _ cx cy cstride cbuff)
      | isBadGemm tra trb ax ay bx by cx cy = error $! "bad dimension args to GEMM: ax ay bx by cx cy: " ++ show [ax, ay, bx, by, cx, cy]
      | SM.overlaps abuff cbuff || SM.overlaps bbuff cbuff =
          error $ "the read and write inputs for: " ++ gemmName ++ " overlap. This is a programmer error. Please fix."
      | otherwise =
          {- FIXME : Add Sharing check that also errors out for now-}
          unsafeWithPrim abuff $ \ap ->
          unsafeWithPrim bbuff $ \bp ->
          unsafeWithPrim cbuff $ \cp ->
          constHandler alpha $ \alphaPtr ->
          constHandler beta $ \betaPtr ->
            do  (ax,ay) <- return $ coordSwapper tra (ax,ay)
                --- don't need to swap b, info is in a and c
                --- c doesn't get implicitly transposed
                blasOrder <- return $ encodeNiceOrder ornta -- all three are the same orientation
                rawTra <- return $ encodeFFITranpose tra
                rawTrb <- return $ encodeFFITranpose trb
                -- example of why i want to switch to singletons
                unsafePrimToPrim $! (if shouldCallFast cy cx ax then gemmUnsafeFFI else gemmSafeFFI)
                    blasOrder rawTra rawTrb (fromIntegral cy) (fromIntegral cx) (fromIntegral ax)
                    alphaPtr ap (fromIntegral astride) bp (fromIntegral bstride) betaPtr cp (fromIntegral cstride)
so what do you want to understand about the above! ask!
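To make the dimension check concrete, here is a small standalone sketch of isBadGemm and coordSwapper with a few sample shapes. It drops the bang patterns and redefines Transpose locally so it compiles without hblas. Reading each (x, y) pair as (columns, rows) is an assumption, based on how cy, cx and ax are passed to the FFI call above; the names okPlain, okTransposed and rejectedMismatch are made up for illustration.

```haskell
-- Standalone copy of the dimension check; Transpose is redefined locally
-- so that this compiles without hblas.
data Transpose = NoTranspose | Transpose | ConjTranspose | ConjNoTranspose
  deriving (Eq, Show)

-- Swap the dimension pair whenever the operand is used transposed.
coordSwapper :: Transpose -> (a, a) -> (a, a)
coordSwapper NoTranspose     (a, b) = (a, b)
coordSwapper ConjNoTranspose (a, b) = (a, b)
coordSwapper Transpose       (a, b) = (b, a)
coordSwapper ConjTranspose   (a, b) = (b, a)

-- True when the shapes cannot be multiplied as op(A) * op(B) into C.
isBadGemm :: Transpose -> Transpose -> Int -> Int -> Int -> Int -> Int -> Int -> Bool
isBadGemm tra trb ax ay bx by cx cy =
    helper (coordSwapper tra (ax, ay)) (coordSwapper trb (bx, by)) (cx, cy)
  where
    helper (ax', ay') (bx', by') (cx', cy') =
         minimum [ax', ay', bx', by', cx', cy'] <= 0
      || not (cy' == ay' && cx' == bx' && ax' == by')

-- Sample shapes, assuming (x, y) means (columns, rows):
-- A is 2 rows by 3 columns, B is 3 by 4, C is 2 by 4.
okPlain :: Bool
okPlain = not (isBadGemm NoTranspose NoTranspose 3 2 4 3 4 2)

-- A stored as 3 rows by 2 columns but used transposed: still accepted.
okTransposed :: Bool
okTransposed = not (isBadGemm Transpose NoTranspose 2 3 4 3 4 2)

-- The same storage without the transpose flag is rejected.
rejectedMismatch :: Bool
rejectedMismatch = isBadGemm NoTranspose NoTranspose 2 3 4 3 4 2
```

Note that only A and B are swapped before comparing; C is always taken as stored, matching the "c doesn't get implicitly transposed" comment.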
notice that the result function go
has the type
forall orient . GemmFun el orient (PrimState m) m
the wrapped up dgemm is defined like so
dgemm :: PrimMonad m =>
    Transpose -> Transpose -> Double -> Double -> MutDenseMatrix (PrimState m) orient Double -> MutDenseMatrix (PrimState m) orient Double -> MutDenseMatrix (PrimState m) orient Double -> m ()
-- argument order: name, safe FFI variant, unsafe FFI variant, constant handler
dgemm = gemmAbstraction "dgemm" cblas_dgemm_safe cblas_dgemm_unsafe (\x f -> f x)
the last function bit
(\x f -> f x)
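is there because in the complex case we need to actually pack the complex number into a struct and pass a pointer to it along to the ffi'd code; for the real types the scalar is passed by value, so the handler just applies the continuation to it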
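A minimal sketch of the two kinds of handler, specialised to IO for simplicity (the hblas version is generic over PrimMonad m, so this is not the library's code). The names realHandler, complexHandler and roundTrip are hypothetical; `with` from Foreign.Marshal.Utils is one way to get a temporary pointer to a Storable value, and is an assumption about how the complex case could be done.

```haskell
import Data.Complex (Complex (..))
import Data.IORef (newIORef, readIORef, writeIORef)
import Foreign.Marshal.Utils (with)
import Foreign.Ptr (Ptr)
import Foreign.Storable (peek)

-- Real scalars go to the FFI by value, so the handler only applies
-- the continuation (this is the `\x f -> f x` case).
realHandler :: Double -> (Double -> IO ()) -> IO ()
realHandler x k = k x

-- Complex scalars go by pointer: `with` writes the value to temporary
-- memory, runs the continuation on the pointer, then frees the memory.
complexHandler :: Complex Double -> (Ptr (Complex Double) -> IO ()) -> IO ()
complexHandler z k = with z k

-- Hypothetical helper: read the scalar back through the pointer,
-- the way the C side would see it.
roundTrip :: Complex Double -> IO (Complex Double)
roundTrip z = do
  ref <- newIORef (0 :+ 0)
  complexHandler z (\p -> peek p >>= writeIORef ref)
  readIORef ref
```

Both handlers have the shape `el -> (scale -> m ()) -> m ()`, with scale being the element itself in the real case and a pointer in the complex case, which is why one abstraction can serve all four gemm variants.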
the transpose type is in https://github.com/wellposed/hblas/blob/master/src/Numerical/HBLAS/MatrixTypes.hs#L100
data Transpose = NoTranspose | Transpose | ConjTranspose | ConjNoTranspose
deriving(Typeable,Eq,Show)
it's worth re-emphasizing the type
gemmAbstraction :: (SM.Storable el, PrimMonad m) => String ->
    GemmFunFFI scale el -> GemmFunFFI scale el -> (el -> (scale -> m ()) -> m ()) -> forall orient . GemmFun el orient (PrimState m) m
GemmFunFFI is at https://github.com/wellposed/hblas/blob/master/src/Numerical/HBLAS/BLAS/FFI.hs#L440
-- | Matrix mult for general dense matrices
type GemmFunFFI scale el = CBLAS_ORDERT -> CBLAS_TRANSPOSET -> CBLAS_TRANSPOSET ->
    CInt -> CInt -> CInt -> {- scal A * B -} scale -> {- Matrix A-} Ptr el -> CInt -> {- B -} Ptr el -> CInt ->
    scale -> {- C -} Ptr el -> CInt -> IO ()
{- C := alpha*op( A )*op( B ) + beta*C , -}
-- matrix mult!
foreign import ccall unsafe "cblas_sgemm"
    cblas_sgemm_unsafe :: GemmFunFFI Float Float
foreign import ccall unsafe "cblas_dgemm"
    cblas_dgemm_unsafe :: GemmFunFFI Double Double
foreign import ccall unsafe "cblas_cgemm"
    cblas_cgemm_unsafe :: GemmFunFFI (Ptr (Complex Float)) (Complex Float)
foreign import ccall unsafe "cblas_zgemm"
    cblas_zgemm_unsafe :: GemmFunFFI (Ptr (Complex Double)) (Complex Double)
-- safe ffi variant for large inputs
foreign import ccall "cblas_sgemm"
    cblas_sgemm_safe :: GemmFunFFI Float Float
foreign import ccall "cblas_dgemm"
    cblas_dgemm_safe :: GemmFunFFI Double Double
foreign import ccall "cblas_cgemm"
    cblas_cgemm_safe :: GemmFunFFI (Ptr (Complex Float)) (Complex Float)
foreign import ccall "cblas_zgemm"
    cblas_zgemm_safe :: GemmFunFFI (Ptr (Complex Double)) (Complex Double)
to quote the mkl blas docs
The ?gemm routines perform a matrix-matrix operation with general matrices. The operation is defined as
C := alpha*op(A)*op(B) + beta*C,
where:
op(x) is one of op(x) = x, or op(x) = x', or op(x) = conjg(x'),
the choice of op(x) is determined by
Transpose = NoTranspose | Transpose | ConjTranspose | ConjNoTranspose
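For reference, the operation quoted above can be written out as a tiny pure function on lists of rows. This is only a sketch of the semantics for real element types, not how hblas or BLAS computes anything; gemmRef, Op and applyOp are names made up for the illustration.

```haskell
import Data.List (transpose)

-- Local stand-in for the transpose flag; for real element types the
-- conjugating variants coincide with these two, so they are left out.
data Op = NoTrans | Trans

applyOp :: Op -> [[a]] -> [[a]]
applyOp NoTrans = id
applyOp Trans   = transpose

-- C := alpha * op(A) * op(B) + beta * C, on row-major lists of rows.
-- Shapes are assumed to be compatible; no checking is done here.
gemmRef :: Num a => Op -> Op -> a -> a -> [[a]] -> [[a]] -> [[a]] -> [[a]]
gemmRef ta tb alpha beta a b c =
    zipWith (zipWith combine) prod c
  where
    a'   = applyOp ta a
    b'   = applyOp tb b
    -- plain matrix product: each entry is a row of a' dotted with a column of b'
    prod = [[sum (zipWith (*) row col) | col <- transpose b'] | row <- a']
    combine ab cij = alpha * ab + beta * cij
```

Unlike this version, the real routine overwrites C in place, which is why the wrapper above refuses inputs whose buffers overlap with C.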
closing this for now, since it kinda documents most things that someone who's familiar with the haskell FFI will need help with
a key snippet of gemmAbstraction for understanding the use of both the unsafe and safe ffi stuff (i'm reproducing it from above)
unsafePrimToPrim $! (if shouldCallFast cy cx ax then gemmUnsafeFFI else gemmSafeFFI )
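the idea being "when the compute is sooooo fast (roughly <= 1-10µs worth of FLOPs), use the unsafe FFI call, since its lower per-call overhead matters at that scale; for anything above the complexity threshold, the overhead of the SAFE ffi call is small relative to the work, so we should use that instead and not block other haskell threads while BLAS runs"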
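The selection rule can be pulled out into a pure function to see which sizes land on which side. flopsThreshold and gemmComplexity are copied from the code above; CallKind and chooseCall are hypothetical names used only for this sketch.

```haskell
-- Which flavour of foreign call would be picked for a given problem size.
data CallKind = UnsafeCall | SafeCall
  deriving (Eq, Show)

flopsThreshold :: Int
flopsThreshold = 10000

-- Rough cost estimate: off by a small constant factor, as noted above.
gemmComplexity :: Int -> Int -> Int -> Int
gemmComplexity a b c = a * b * c

-- Same comparison as shouldCallFast: at or below the threshold the
-- unsafe (low overhead) call is used, above it the safe one.
chooseCall :: Int -> Int -> Int -> CallKind
chooseCall m n k
  | flopsThreshold >= gemmComplexity m n k = UnsafeCall
  | otherwise                              = SafeCall
```

By this estimate a 10 by 10 by 10 product costs 1000 and takes the unsafe path, while 100 by 100 by 100 costs 1000000 and takes the safe one.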
What high-level type of sgemv do we want? Something perhaps like
sgemv :: PrimMonad m
      => Transpose -> Float -> Float
      -> MutDenseMatrix (PrimState m) orient Float -- A
      -> MVector (PrimState m) Float -- x
      -> MVector (PrimState m) Float -- y
      -> MVector (PrimState m) Float -- result
      -> m ()
?
yup, that's a good start, just implement that and dgemv, ignore complex for now while you're learning
woops, should be
sgemv :: PrimMonad m
      => Transpose -> Float -> Float
      -> MutDenseMatrix (PrimState m) orient Float
      -> MVector (PrimState m) Float
      -> MVector (PrimState m) Float
      -> m ()
Huh, realized that y is both an input vector and a result vector, so there should only be two, yeah.
for those following this thread at home, a good BLAS / lapack doc can be found here http://software.intel.com/sites/products/documentation/hpc/mkl/mklman/index.htm
with the docs on sgemv in cblas here http://software.intel.com/sites/products/documentation/hpc/mkl/mklman/GUID-25178576-05F1-4A33-8A0E-3694F0CCD242.htm
So I'm looking into modifying/tweaking the current gemmAbstraction from BLAS.hs and adapting it for gemv; any sort of pointers on what aspects are likely to need changing?
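One piece that will probably need changing is the dimension check, since gemv computes y := alpha*op(A)*x + beta*y with vectors in place of B and C. Below is a rough sketch of what a gemv analogue of isBadGemm might look like. isBadGemv is a hypothetical name, not something in hblas, and it takes plain row and column counts instead of the ax/ay pairs to avoid guessing at the matrix type's conventions.

```haskell
data Transpose = NoTranspose | Transpose | ConjTranspose | ConjNoTranspose
  deriving (Eq, Show)

-- Hypothetical gemv analogue of isBadGemm.
-- rows, cols: shape of A as stored; xLen, yLen: lengths of x and y.
-- op(A) must have as many columns as x has entries and as many rows
-- as y has entries.
isBadGemv :: Transpose -> Int -> Int -> Int -> Int -> Bool
isBadGemv tr rows cols xLen yLen =
     minimum [rows, cols, xLen, yLen] <= 0
  || not (xLen == opCols && yLen == opRows)
  where
    (opRows, opCols) = case tr of
      NoTranspose     -> (rows, cols)
      ConjNoTranspose -> (rows, cols)
      Transpose       -> (cols, rows)
      ConjTranspose   -> (cols, rows)
```

The other likely differences are the cost estimate (on the order of rows * cols instead of a product of three dimensions), the overlap check (x must not alias y), and the FFI type, since the cblas gemv routines take an increment for each vector where gemm takes a leading dimension.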