mstksg / backprop

Heterogeneous automatic differentiation ("backpropagation") in Haskell
https://backprop.jle.im
BSD 3-Clause "New" or "Revised" License

Backpropagating through "stateful" operations. #16

Open jh14778 opened 4 years ago

jh14778 commented 4 years ago

I am trying to use this library with a third-party C library, but I need to carry some additional context to perform the computations.

My issue boils down to how do I integrate library calls like:

-- Haskell wrapper function signature 
zero :: (Int, Int) -> StateT Context IO Matrix
one :: (Int, Int) -> StateT Context IO Matrix
add :: Matrix -> Matrix -> StateT Context IO Matrix
matMul :: Matrix -> Matrix -> StateT Context IO Matrix

-- FFI signature 
zero :: Ptr Context -> CInt -> CInt -> IO (Ptr Matrix)
one :: Ptr Context -> CInt -> CInt -> IO (Ptr Matrix)
add :: Ptr Context -> Ptr Matrix -> Ptr Matrix -> IO (Ptr Matrix)
matmul :: Ptr Context -> Ptr Matrix -> Ptr Matrix -> IO (Ptr Matrix)

There are no standalone functions to zero, one, or add a Matrix without carrying the Context.
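
For context, backprop's Backprop class (exported from Numeric.Backprop) asks for exactly these three operations, and all of them are pure: zero :: a -> a, add :: a -> a -> a and one :: a -> a. There is no slot for a Context or for IO. A minimal sketch of such an instance, using a hypothetical pure Matrix newtype over nested lists as a stand-in for the C library's type:

```haskell
import Numeric.Backprop (Backprop (..))

-- Hypothetical pure stand-in for the C library's matrix (not the FFI type).
newtype Matrix = Matrix [[Double]]
  deriving (Show, Eq)

-- Every method is a plain function: no Context argument and no IO.
-- The input value only supplies the shape of the result.
instance Backprop Matrix where
  zero (Matrix rows) = Matrix (map (map (const 0)) rows)
  one (Matrix rows) = Matrix (map (map (const 1)) rows)
  add (Matrix a) (Matrix b) = Matrix (zipWith (zipWith (+)) a b)

main :: IO ()
main = do
  let m = Matrix [[1, 2], [3, 4]]
  print (zero m)
  print (one m)
  print (add m m)
```

Because zero and one receive an existing value rather than an (Int, Int), the dimensions come from that value instead of being passed explicitly.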

ocramz commented 4 years ago

I've only recently started using backprop, but, much like the ad library, it assumes that the computation being differentiated is pure. I can't follow the library's internals well enough to tell why that is the case, though.
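
To illustrate that assumption: the function handed to evalBP or gradBP has type forall s. Reifies s W => BVar s a -> BVar s b, with no monad anywhere in it, so there is nowhere to thread a Context. A small sketch on Double (poly is a hypothetical name):

```haskell
{-# LANGUAGE FlexibleContexts #-}
import Numeric.Backprop (BVar, Reifies, W, evalBP, gradBP)

-- A pure function of a BVar; its type leaves no room for IO or state.
poly :: Reifies s W => BVar s Double -> BVar s Double
poly x = x * x + 3 * x

main :: IO ()
main = do
  print (evalBP poly 2) -- value at 2: 2*2 + 3*2 = 10.0
  print (gradBP poly 2) -- derivative 2*x + 3 at 2: 7.0
```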

jh14778 commented 4 years ago

I suspected that was the case.

My library calls are "pure" as far as the computed values are concerned. Unfortunately, they have observable side effects on the context (unsafePerformIO might not be acceptable here either).

Perhaps this is more of a feature request. I'm happy for this to be closed, if there's no plan for this to be supported.

ivanovs-4 commented 4 years ago

Maybe the MVar trick will do:

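import Control.Concurrent
import Control.Monad.State
import Control.Monad.Trans.Class (lift)
import Data.Tuple
import System.IO.Unsafe

type Context = Int
type Matrix = [[Double]]

-- The shared context lives in an MVar so that pure code can reach it.
newtype Ctx = Ctx { unCtx :: MVar Context }

newCtx :: IO Context -> IO Ctx
newCtx ioc = fmap Ctx $ newMVar =<< ioc

-- Stand-in for a stateful library call: bumps the context and logs it.
zero :: (Int, Int) -> StateT Context IO Matrix
zero xy@(x,y) = do
    modify' succ
    s <- get
    lift $ print $ (s, xy)
    pure $ replicate y $ replicate x $ 0

-- Pure-looking version of zero that threads the context through the MVar.
zero' :: Ctx -> (Int, Int) -> Matrix
zero' ctx xy = wrap ctx $ zero xy

-- Run the action against the MVar's current context, store the updated
-- context back, and return the result. unsafePerformIO hides the IO, so
-- GHC is free to share or reorder these calls.
wrap :: Ctx -> StateT Context IO a -> a
wrap ctx ma = unsafePerformIO $
    modifyMVar (unCtx ctx) $ fmap swap . runStateT ma

main :: IO ()
main = do
  ctx <- newCtx $ pure 0
  print $ zero' ctx (2,3)
  print $ zero' ctx (1,1)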
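
One possible way to connect this trick to backprop itself (a sketch under assumptions, not checked against the real C library): since the Backprop methods take no extra argument, the MVar would have to be a top-level global, created with the usual NOINLINE unsafePerformIO idiom. Matrix, globalCtx, wrapG, zeroS, oneS and addS are hypothetical names, and the StateT actions only bump a counter where the real FFI calls would go. Arguments are deep-forced before the MVar is taken, because forcing a lazy argument that itself calls wrapG while the MVar is held would block on that same MVar.

```haskell
import Control.Concurrent.MVar (MVar, modifyMVar, newMVar, readMVar)
import Control.DeepSeq (NFData (..), deepseq)
import Control.Monad.State (StateT, modify', runStateT)
import Data.Tuple (swap)
import Numeric.Backprop (Backprop (..))
import System.IO.Unsafe (unsafePerformIO)

type Context = Int

-- Hypothetical pure stand-in for the foreign Matrix type.
newtype Matrix = Matrix [[Double]] deriving (Show)

instance NFData Matrix where
  rnf (Matrix rows) = rnf rows

-- Hypothetical process-wide context. NOINLINE keeps GHC from inlining
-- the definition, which could otherwise create more than one MVar.
globalCtx :: MVar Context
globalCtx = unsafePerformIO (newMVar 0)
{-# NOINLINE globalCtx #-}

-- Like wrap above, but always uses the global context.
wrapG :: StateT Context IO a -> a
wrapG ma = unsafePerformIO $ modifyMVar globalCtx (fmap swap . runStateT ma)

-- Stand-ins for the stateful library calls; each one bumps the context.
zeroS, oneS :: Matrix -> StateT Context IO Matrix
zeroS (Matrix rows) = modify' succ >> pure (Matrix (map (map (const 0)) rows))
oneS (Matrix rows) = modify' succ >> pure (Matrix (map (map (const 1)) rows))

addS :: Matrix -> Matrix -> StateT Context IO Matrix
addS (Matrix a) (Matrix b) =
  modify' succ >> pure (Matrix (zipWith (zipWith (+)) a b))

-- Arguments are fully evaluated before wrapG takes the MVar, so no
-- nested wrapG call can run while the MVar is held.
instance Backprop Matrix where
  zero m = m `deepseq` wrapG (zeroS m)
  one m = m `deepseq` wrapG (oneS m)
  add a b = a `deepseq` b `deepseq` wrapG (addS a b)

main :: IO ()
main = do
  let m = Matrix [[1, 2], [3, 4]]
  print (add m (one m))
  readMVar globalCtx >>= print -- number of context updates so far
```

As with any use of unsafePerformIO, GHC may share or duplicate these calls under optimisation, so the number of context updates is not guaranteed to match the source.
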

ocramz commented 3 years ago

@jh14778 @ivanovs-4 I just recalled: you could also wrap your effects inside ABP, which will provide the Backprop instance for you: https://hackage.haskell.org/package/backprop-0.2.6.4/docs/Numeric-Backprop.html#t:ABP
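
A small illustration of ABP, assuming the documented behaviour that its Backprop instance lifts the methods through the Applicative (add via liftA2, zero and one via fmap). tick, a and b are hypothetical, and State Int stands in for the context. Under that assumption, the lifted zero and one only map the pure zero and one over the result, so they would not invoke the C library's own zero or one.

```haskell
import Control.Monad.State (State, modify, runState)
import Numeric.Backprop (ABP (..), Backprop (..))

-- Hypothetical effect: count every time a value is produced.
tick :: Double -> State Int Double
tick x = modify (+ 1) >> pure x

a, b :: ABP (State Int) Double
a = ABP (tick 1)
b = ABP (tick 2)

main :: IO ()
main = do
  -- add on ABP combines both actions, so both ticks should run.
  print (runState (runABP (add a b)) 0)
  -- zero on ABP maps zero over the result; the tick itself still runs.
  print (runState (runABP (zero a)) 0)
```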