tweag / monad-bayes

A library for probabilistic programming in Haskell.
MIT License

Why reimplement Functor, Applicative, Monad instances for known transformers? #209

Status: Open. turion opened this issue 1 year ago.

turion commented 1 year ago

Many monads defined in monad-bayes are isomorphic to well-known transformers. For example:

-- | Collection of random variables sampled during the program's execution.
data Trace a = Trace
  { -- | Sequence of random variables sampled during the program's execution.
    variables :: [Double],
    -- | The program's output.
    output :: a,
    -- | The probability of observing this particular sequence.
    density :: Log Double
  }

This is isomorphic to:

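-- | Collection of random variables sampled during the program's execution.
data TraceData = TraceData
  { -- | Sequence of random variables sampled during the program's execution.
    variables :: [Double],
    -- | The probability of observing this particular sequence.
    density :: Log Double
  }

-- Semigroup and Monoid are not stock-derivable; sequencing two traces
-- concatenates the variables and multiplies the densities.
instance Semigroup TraceData where
  TraceData v1 d1 <> TraceData v2 d2 = TraceData (v1 ++ v2) (d1 * d2)

instance Monoid TraceData where
  mempty = TraceData [] 1

newtype Trace a = Trace { getTrace :: Writer TraceData a }
  deriving newtype (Functor, Applicative, Monad)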

The advantage is that this way one gets correct and efficient Functor, Applicative and Monad instances for free, with less code to maintain. (It is notoriously difficult to write a performant Writer, so it is better to reuse existing code than to reimplement it.)
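A minimal, self-contained sketch of the Writer-based encoding, using only base and transformers so it can be loaded in GHCi. Assumptions: a plain Double stands in for Log Double (to avoid depending on log-domain), and sample and runTrace are hypothetical helpers, not monad-bayes functions. The point it illustrates is that the Monoid TraceData instance alone encodes the sequencing behaviour (variables are concatenated in order, densities multiplied), and the newtype-derived Monad instance of Writer does the rest.

```haskell
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

import Control.Monad.Trans.Writer (Writer, runWriter, writer)

-- Random variables sampled so far and the probability of that sequence.
-- A plain Double stands in for Log Double (assumption, to avoid log-domain).
data TraceData = TraceData
  { variables :: [Double],
    density :: Double
  }
  deriving stock (Show, Eq)

-- Sequencing two traces concatenates the variables and multiplies densities.
instance Semigroup TraceData where
  TraceData vs1 d1 <> TraceData vs2 d2 = TraceData (vs1 ++ vs2) (d1 * d2)

-- The empty trace: no variables, density 1 (the unit of multiplication).
instance Monoid TraceData where
  mempty = TraceData [] 1

-- Functor, Applicative and Monad are all inherited from Writer TraceData.
newtype Trace a = Trace {getTrace :: Writer TraceData a}
  deriving newtype (Functor, Applicative, Monad)

-- Hypothetical primitive: record one sampled value together with its density.
sample :: Double -> Double -> Trace Double
sample x p = Trace (writer (x, TraceData [x] p))

-- One direction of the isomorphism with the record-based Trace:
-- the program's output paired with the recorded data.
runTrace :: Trace a -> (a, TraceData)
runTrace = runWriter . getTrace

-- Two samples in sequence; the trace data is combined by the Writer bind.
example :: Trace Double
example = do
  x <- sample 0.5 0.5
  y <- sample 0.25 0.25
  pure (x + y)
```

Here runTrace example yields the output 0.75 together with variables [0.5, 0.25] and density 0.125, which is what a hand-written bind that appends variables and multiplies densities would also produce.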

Similarly:

-- | Tracing monad that records random choices made in the program.
data Traced m a = Traced
  { -- | Run the program with a modified trace.
    model :: Weighted (FreeSampler Identity) a,
    -- | Record trace and output.
    traceDist :: m (Trace a)
  }

This is the same as promoting Trace to a transformer and taking the functor product (Data.Functor.Product) with the model:

newtype TraceT m a = TraceT { getTraceT :: WriterT TraceData m a }
  deriving newtype (Functor, Applicative, Monad)

-- | Tracing monad that records random choices made in the program.
newtype Traced m a = Traced { getTraced :: Product (Weighted (FreeSampler Identity)) (TraceT m) a }
  deriving newtype (Functor, Applicative, Monad)

Again, all the instances now come for free. The Haddock comments on the record fields can be moved to a custom constructor function:

traced ::
  -- | Run the program with a modified trace.
  Weighted (FreeSampler Identity) a ->
  -- | Record trace and output.
  m (Trace a) ->
  Traced m a
traced = ...
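The product encoding can be exercised the same way. This is a sketch under simplifying assumptions, not the monad-bayes code: Model is a hypothetical stand-in for Weighted (FreeSampler Identity) (here just Identity), TraceData keeps only the sampled values, and the bodies of traced, sampleT and runTraced are illustrative. Every instance of Traced is obtained by newtype deriving through Data.Functor.Product, whose Monad instance binds each component separately.

```haskell
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

import Control.Monad.Trans.Writer (Writer, WriterT (..), runWriter, writer)
import Data.Functor.Identity (Identity (..))
import Data.Functor.Product (Product (..))

-- Simplified trace data: only the sampled values, no density (assumption).
newtype TraceData = TraceData [Double]
  deriving stock (Show, Eq)
  deriving newtype (Semigroup, Monoid)

-- Hypothetical stand-in for Weighted (FreeSampler Identity).
type Model = Identity

-- Trace promoted to a transformer; instances come from WriterT.
newtype TraceT m a = TraceT {getTraceT :: WriterT TraceData m a}
  deriving newtype (Functor, Applicative, Monad)

-- Product of the model and the trace transformer; instances come from
-- Data.Functor.Product, which binds each component independently.
newtype Traced m a = Traced {getTraced :: Product Model (TraceT m) a}
  deriving newtype (Functor, Applicative, Monad)

-- Illustrative smart constructor; the argument comments replace the old
-- record-field documentation.
traced ::
  Functor m =>
  -- Run the program with a modified trace.
  Model a ->
  -- Record trace and output.
  m (Writer TraceData a) ->
  Traced m a
traced model traceDist =
  Traced (Pair model (TraceT (WriterT (fmap runWriter traceDist))))

-- Illustrative primitive: return x in the model and record it in the trace.
sampleT :: Applicative m => Double -> Traced m Double
sampleT x = traced (Identity x) (pure (writer (x, TraceData [x])))

-- Split a Traced computation back into its two components.
runTraced :: Traced m a -> (Model a, m (a, TraceData))
runTraced (Traced (Pair model (TraceT w))) = (model, runWriterT w)

example :: Traced Identity Double
example = do
  x <- sampleT 1
  y <- sampleT 2
  pure (x + y)
```

With these definitions, both components of runTraced example carry the output 3, and the trace component records [1, 2] in sampling order, matching what a hand-written bind over the (model, traceDist) record would compute.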
reubenharry commented 1 year ago

Highly in favour of this refactor!