turion / rhine

Haskell Functional Reactive Programming framework with type-level clocks
http://hackage.haskell.org/package/rhine

Inference resampling buffers #289

Open turion opened 9 months ago

turion commented 9 months ago

@reubenharry your ideas in https://github.com/turion/rhine/issues/281 have given me an idea in turn. We've discussed a few times that resampling buffers in rhine-bayes should really arise from inference. I think I have a proposal for how this should look.

Principle

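A resampling buffer is basically an effectful Moore machine: it has a method to put new data into it (and thus alter its state), and a method to get data from its state. In rhine-bayes, these two methods should have the following meanings:
  • put: step the prior simulation forward to the current time, then condition the estimate on the new observation.
  • get: step the prior simulation forward to the current time and return the current estimate.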

Pseudocode

To implement a buffer as a particle filter, we need this data:

-- | A stochastic process whose development in time we want to use for resampling
model :: StochasticProcess time a
model = ...

-- | Given state @a@, what is the likelihood of @b@ occurring?
--   This is used to weight the different particles later.
observation :: MonadFactor m => a -> b -> m ()
observation = ...

myInferenceBuffer :: MonadMeasure m => ResamplingBuffer m clA clB b [(a, Probability)]
myInferenceBuffer = inferenceBuffer model observation
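To make the intended semantics concrete, here is a dependency-free sketch of what inferenceBuffer could do. It simplifies the proposal under several assumptions that are not in it: exact enumeration over a finite state space instead of sampled particles, integer ticks instead of rhine clocks, an explicit initial distribution argument, and a hypothetical Buffer record instead of rhine's ResamplingBuffer. Both put and get first advance the prior to the requested time; only put conditions on an observation.

```haskell
import Data.Function (on)
import Data.List (groupBy, sortOn)

-- | A discrete distribution: (value, probability) pairs.
type Dist a = [(a, Double)]

-- | Merge duplicate values and rescale the weights so they sum to 1.
normalize :: Ord a => Dist a -> Dist a
normalize d = [ (x, w / total) | (x, w) <- merged ]
  where
    merged = [ (fst (head grp), sum (map snd grp)) | grp <- groupBy ((==) `on` fst) (sortOn fst d) ]
    total = sum (map snd merged)

-- | Push a distribution through one tick of the prior (a Markov transition).
stepPrior :: Ord a => (a -> Dist a) -> Dist a -> Dist a
stepPrior transition d = normalize [ (y, w * p) | (x, w) <- d, (y, p) <- transition x ]

-- | Hypothetical simplified buffer: time is an integer tick count.
data Buffer b a = Buffer
  { put :: Int -> b -> Buffer b a        -- ^ condition on an observation at a given tick
  , get :: Int -> (Dist a, Buffer b a)   -- ^ current estimate at a given tick
  }

-- | Hypothetical inferenceBuffer: put and get both advance the prior to the
--   requested tick; only put reweights by the likelihood. Assumes non-decreasing ticks.
inferenceBuffer :: Ord a => (a -> Dist a) -> (a -> b -> Double) -> Dist a -> Buffer b a
inferenceBuffer transition likelihood prior = go 0 (normalize prior)
  where
    advance from to d = iterate (stepPrior transition) d !! max 0 (to - from)
    go now d = Buffer
      { put = \t b ->
          let d' = advance now t d
          in go t (normalize [ (x, w * likelihood x b) | (x, w) <- d' ])
      , get = \t ->
          let d' = advance now t d
          in (d', go t d')
      }
```

For example, with a two-state chain that flips with probability 0.1 and a sensor that is right 80% of the time, putting the reading True at tick 0 moves the estimate from 0.5/0.5 to 0.2/0.8 (False/True), and a get at tick 1 returns about 0.26/0.74, because the prior has advanced one step in between.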

Example in use

Let's think about how the Brownian motion example would simplify. If, for the moment, we disregard the varying temperature, the whole rhine currently looks schematically like this:

simulation >-- keepLast default --> inference >-- keepLast default --> visualization

This setup is unsatisfactory in a few ways: the use of keepLast is ad hoc, and inference may run several times on the same values if it ticks more often than the simulation. Also, we keep creating estimates of the current state more often than we can visualize them, which is wasteful.

I think it would be much better if the inference is activated exactly on every simulation step, and the current estimate is retrieved exactly on every step of the visualization! This would be achieved with this setup:

simulation >-- myInferenceBuffer --> visualization

This is much simpler, with no ad hoc choices, and I believe it performs better both in runtime and in estimation quality.

Variation

This opens the door to a fun game we could implement: instead of the simulation creating sensor readings, we let the user try to click on the green dot (the latent position), and the clicked position becomes the sensor reading. This way, users can get a feel for how Bayesian inference works.

Open questions

  1. If the model contains a simulation that uses imprecise integration, e.g. Euler integration, then the inference step may produce bad estimates if its simulation isn't stepped often enough. This may happen if the visualization has a low frame rate and the sensor readings don't come in fast enough. One can solve this by polling not only at every visualization step, but also at a sufficiently high regular rate, by combining the visualization clock with a constant-rate clock using ParallelClock.
  2. I don't understand how to pass an additional external parameter like the temperature to the buffer. If it is to be estimated, then this is fine (we use the existing workaround to make a stochastic process out of it), but if the buffer should use it as input, then I don't know how to pass it on, because the get method doesn't take input. One solution might be extending resampling buffers to do input and output at the same time. But I don't want to extend the framework just to accommodate one use case. With the existing framework, I can see two ad hoc solutions:
    • Using a global StateT, update the parameter from one component and read it from the buffer. This should work but it clutters the type level and doesn't feel idiomatic.
    • put could accept Either b p, where p is the extra parameter. But then some extra organisation of the data has to happen before the buffer, intermingling measurements and parameters.
  3. I don't understand what the clocks for the resampling buffer should be. Are they completely arbitrary, or can model & observation likelihood contain clock information?
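The second workaround in point 2 (put accepting Either b p) could look roughly like this. It is a minimal sketch with hypothetical names: the external parameter is assumed to be a drift velocity that moves every particle on each measurement (standing in for the temperature's effect on the dynamics), the likelihood is assumed to be a unit-variance Gaussian, and particles are only reweighted, never resampled.

```haskell
-- | Hypothetical buffer state: the latest external parameter (a drift velocity,
--   standing in for the temperature) and a set of weighted particles.
data BufState = BufState
  { param     :: Double
  , particles :: [(Double, Double)] -- ^ (position, normalized weight)
  } deriving Show

-- | Rescale the weights so they sum to 1.
normalize :: [(Double, Double)] -> [(Double, Double)]
normalize ps = [ (x, w / total) | (x, w) <- ps ]
  where total = sum (map snd ps)

-- | Unnormalized Gaussian likelihood (unit variance) of reading b at true position x.
likelihood :: Double -> Double -> Double
likelihood x b = exp (negate ((x - b) ^ (2 :: Int)) / 2)

-- | put for the Either workaround:
--   Left b  = measurement: move every particle by the current parameter, then condition on b.
--   Right p = new parameter: only remember it; no inference happens.
putEither :: Either Double Double -> BufState -> BufState
putEither (Left b) (BufState p ps) =
  BufState p (normalize [ (x', w * likelihood x' b) | (x, w) <- ps, let x' = x + p ])
putEither (Right p') s = s { param = p' }
```

The cost mentioned above is visible here: whoever feeds the buffer has to merge measurements and parameter updates into one Either stream, in the right order.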
reubenharry commented 9 months ago

OK, this is cool! I have a few questions, just to make sure I understand the idea:

  1. By StochasticProcess, do you mean a ClSF? If so, why is it not just model :: StochasticProcess () a?
  2. Does model represent the prior? (i.e. the prior distribution over paths, before receiving information about observations)
  3. Viewed as a Moore machine, what is the exact type of a (inference) resampling buffer?

Relatedly, can we write a particle filter for a Moore machine? That is, something like:

newtype MooreT m a b = MooreT {runMooreT :: m (b, a -> MooreT m a b)}

pf :: Monad m => MooreT (PopulationT m) a b -> MooreT m a [(b, Log Double)]
pf (MooreT moore) = MooreT do
  x <- runPopulationT moore
  let pop = (\((x,y),z) -> (x,z)) <$> x
  let cont = sequence $ ( (\((x,y),z) -> y)) <$> x
  pure (pop, ...  )

And if so, can we write a particle filter that generalizes across Moore and Mealy machines (in the spirit of the machines library, perhaps)?

turion commented 9 months ago
1. By `StochasticProcess`, do you mean a ClSF?

Yes.

If so, why is it not just model :: StochasticProcess () a?

The first type argument to StochasticProcess is the time domain, e.g. UTCTime, Double etc., so in fact StochasticProcess time a doesn't have an extra input. See https://hackage.haskell.org/package/rhine-bayes-1.1/docs/FRP-Rhine-Bayes.html#t:StochasticProcess

2. Does `model` represent the prior? (i.e. the prior distribution over paths, before receiving information about observations)

Yes.

3. Viewed as a Moore machine, what is the exact type of a (inference) resampling buffer?

So a Moore machine is fundamentally asynchronous while a Mealy machine is synchronous:

data Mealy s a b = Mealy s (a -> s -> (s, b))

data Moore s a b = Moore s (a -> s -> s) (s -> b)

In a Mealy machine, you get exactly one output for each input; in a Moore machine, you can decide externally whether you want to feed input or extract output.
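The difference in who drives the machine shows up directly when running them. In this sketch (the runner functions and the Command type are illustrative names, not from any library), a Mealy machine is fed a list of inputs, whereas a Moore machine is driven by an explicit sequence of put and get commands.

```haskell
data Mealy s a b = Mealy s (a -> s -> (s, b))
data Moore s a b = Moore s (a -> s -> s) (s -> b)

-- | A Mealy machine is driven by its inputs: exactly one output per input.
runMealy :: Mealy s a b -> [a] -> [b]
runMealy (Mealy s0 step) = go s0
  where
    go _ [] = []
    go s (a : as) = let (s', b) = step a s in b : go s' as

-- | A Moore machine is driven from outside: each command either feeds input or reads output.
data Command a = Put a | Get

runMoore :: Moore s a b -> [Command a] -> [b]
runMoore (Moore s0 put get) = go s0
  where
    go _ [] = []
    go s (Put a : cmds) = go (put a s) cmds
    go s (Get : cmds) = get s : go s cmds

-- | Running sums in both styles.
sumMealy :: Mealy Int Int Int
sumMealy = Mealy 0 (\a s -> (s + a, s + a))

sumMoore :: Moore Int Int Int
sumMoore = Moore 0 (+) id
```

runMealy sumMealy [1, 2, 3] yields [1, 3, 6], while runMoore sumMoore [Put 1, Put 2, Get, Get, Put 3, Get] yields [3, 3, 6]: two inputs can arrive between reads, and the same state can be read twice.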

Now the above definition uses an initial encoding with an explicit state type. We can hide the state type by making it existential:

data MealyE a b = forall s . MealyE (Mealy s a b)
data MooreE a b = forall s . MooreE (Moore s a b)

Now there is also a final encoding of these concepts. The idea is to replace the implicit state type by the machine type itself. It's easiest to see this way, I find: There are some functions that step the machines:

step :: MealyE a b -> a -> (MealyE a b, b)

put :: MooreE a b -> a -> MooreE a b
get :: MooreE a b -> b
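With the existential encodings, these stepping functions can be written directly. The hidden state type never escapes; only the input and output types are visible. A minimal sketch:

```haskell
{-# LANGUAGE ExistentialQuantification #-}

data Mealy s a b = Mealy s (a -> s -> (s, b))
data Moore s a b = Moore s (a -> s -> s) (s -> b)

-- | The state type is hidden; only input and output types remain.
data MealyE a b = forall s . MealyE (Mealy s a b)
data MooreE a b = forall s . MooreE (Moore s a b)

-- | Run one transition and re-wrap the machine with its new state.
step :: MealyE a b -> a -> (MealyE a b, b)
step (MealyE (Mealy s f)) a = case f a s of
  (s', b) -> (MealyE (Mealy s' f), b)

-- | Update the hidden state; the transition functions are kept.
put :: MooreE a b -> a -> MooreE a b
put (MooreE (Moore s p g)) a = MooreE (Moore (p a s) p g)

-- | Read the output from the hidden state.
get :: MooreE a b -> b
get (MooreE (Moore s _ g)) = g s
```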

The idea is that the state is updated and the transition functions are kept. These functions form a complete API to deal with a machine, so one can also use it to define the machines ("final encoding"):

data MealyF a b = MealyF (a -> (MealyF a b, b))
data MooreF a b = MooreF
  { put :: a -> MooreF a b
  , get :: b
  }
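One way to see how the two encodings relate: every initial-encoding machine can be converted into the final encoding by capturing its current state in closures, so each step lazily builds the next machine. A sketch (mealyToFinal, mooreToFinal and runMealyF are illustrative names, not from any library):

```haskell
data Mealy s a b = Mealy s (a -> s -> (s, b))
data Moore s a b = Moore s (a -> s -> s) (s -> b)

-- | Final encodings: the machine itself plays the role of the state.
data MealyF a b = MealyF (a -> (MealyF a b, b))
data MooreF a b = MooreF
  { put :: a -> MooreF a b
  , get :: b
  }

-- | Initial to final: the state lives on in the closures.
mealyToFinal :: Mealy s a b -> MealyF a b
mealyToFinal (Mealy s f) = MealyF $ \a ->
  let (s', b) = f a s in (mealyToFinal (Mealy s' f), b)

mooreToFinal :: Moore s a b -> MooreF a b
mooreToFinal (Moore s p g) = MooreF
  { put = \a -> mooreToFinal (Moore (p a s) p g)
  , get = g s
  }

-- | Step a final Mealy machine over a list of inputs.
runMealyF :: MealyF a b -> [a] -> [b]
runMealyF _ [] = []
runMealyF (MealyF f) (a : as) = let (m', b) = f a in b : runMealyF m' as
```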

In practice it turns out that one often wants to have side effects during these steps, so we generalize:

data MealyT m a b = MealyT (a -> m (MealyT m a b, b))
data MooreT m a b = MooreT
  { put :: a -> m (MooreT m a b)
  , get :: m b
  }

This is where I diverge from the convention in https://hackage.haskell.org/package/machines-0.7.3/docs/Data-Machine-MooreT.html, where a single monadic layer wraps both put and get. I believe my approach is more general.
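As an illustration of this convention, here is a hypothetical machine that remembers the latest input, using Maybe as the effect so that get can fail before anything has been put. Note that get has no way to change the state.

```haskell
import Control.Monad (foldM)

-- | Effectful Moore machine in this convention: put and get carry separate effects.
data MooreT m a b = MooreT
  { put :: a -> m (MooreT m a b)
  , get :: m b
  }

-- | Hypothetical example: remember the latest input; get fails (Nothing) until
--   something has been put.
keepLatest :: Maybe a -> MooreT Maybe a a
keepLatest latest = MooreT
  { put = \a -> Just (keepLatest (Just a))
  , get = latest
  }

-- | Feed several inputs in sequence, threading the machine through the monad.
putAll :: Monad m => MooreT m a b -> [a] -> m (MooreT m a b)
putAll = foldM put
```

Here get (keepLatest Nothing) is Nothing, and putAll (keepLatest Nothing) [1, 2, 3] >>= get is Just 3.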

Now you'll recognise (or already knew before) that MealyT = MSF. And BehaviourF m time a b is morally MealyT (ReaderT time m) a b, so we have an extra shared input, which is the time.

But MooreT (ReaderT time m) is not quite ResamplingBuffer. A ResamplingBuffer can also update the state when getting a value:

data ResamplingBuffer m time a b = ResamplingBuffer
  { put :: a -> ReaderT time m (ResamplingBuffer m time a b)
  , get :: ReaderT time m (b, ResamplingBuffer m time a b)
  }
-- Yes, the library definition is slightly different for reasons which don't matter here I believe

The Bayesian idea behind this is that the passage of time is information which should be used to step inference forward.
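A small example of why get needs to return an updated buffer: a FIFO queue removes the element it hands out. In this sketch, ReaderT time m x is written out as the plain function type time -> m x to stay dependency-free; the fifo buffer itself is a hypothetical illustration, loosely following the simplified definition above.

```haskell
-- | Simplified resampling buffer, with ReaderT time m x spelled as time -> m x.
data ResamplingBuffer m time a b = ResamplingBuffer
  { put :: a -> time -> m (ResamplingBuffer m time a b)
  , get :: time -> m (b, ResamplingBuffer m time a b)
  }

-- | Hypothetical FIFO: put enqueues a timestamped value, get dequeues the oldest.
--   Unlike with a MooreT, get returns an updated buffer (the popped element is gone).
fifo :: Monad m => [(time, a)] -> ResamplingBuffer m time a (Maybe (time, a))
fifo queue = ResamplingBuffer
  { put = \a t -> return (fifo (queue ++ [(t, a)]))
  , get = \_ -> return $ case queue of
      [] -> (Nothing, fifo [])
      (x : rest) -> (Just x, fifo rest)
  }
```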

Relatedly, can we write a particle filter for a Moore machine? That is, something like:

newtype MooreT m a b = MooreT {runMooreT :: m (b, a -> MooreT m a b)}

pf :: Monad m => MooreT (PopulationT m) a b -> MooreT m a [(b, Log Double)]
pf (MooreT moore) = MooreT do
  x <- runPopulationT moore
  let pop = (\((x,y),z) -> (x,z)) <$> x
  let cont = sequence $ ( (\((x,y),z) -> y)) <$> x
  pure (pop, ...  )

I'm sure we can, but I think it will be easier if we adopt the convention I proposed. I also find that the initial encoding makes it much easier to think about the situation. So maybe start like this:

data Moore' s m a b = Moore'
  { state :: s
  , put :: a -> s -> m s
  , get :: s -> m b
  }

-- Needs {-# LANGUAGE NamedFieldPuns #-}
pf :: Monad m => Int -> Moore' s m a b -> Moore' [s] m (Either a b) [s]
pf nParticles Moore' {state, put, get} = Moore'
  { state = replicate nParticles state
  , put = \input states -> case input of
      -- We've received input data: put just updates the states (proceed simulating the prior)
      Left a -> traverse (put a) states
      -- We've received a measured observable: somehow condition the states on the measurement
      Right b -> ...
  -- Return the current inference state (assuming all particles are normalized, so the probabilities don't carry information)
  , get = return
  }

The trouble is that when conditioning on measurements, monad-bayes only gives us the blunt tool of MonadFactor. We can of course condition on b ==, but that event often has probability zero. It would be better if there was a function observe :: m b -> b -> m () that would score with the probability of that particular b. I don't really know how to solve this in monad-bayes.
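For distributions that can be enumerated, such an observe is easy to write: score with the total weight the distribution assigns to the observed value. The sketch below uses a hand-rolled Weighted monad as a stand-in, not monad-bayes; its score is analogous to score from MonadFactor but takes a plain Double. This does not carry over to an arbitrary sampler m b, whose probabilities are not available, which is the actual difficulty.

```haskell
-- | Minimal weighted-enumeration monad (a stand-in, not monad-bayes):
--   every value carries an unnormalized weight.
newtype Weighted a = Weighted { runWeighted :: [(a, Double)] }

instance Functor Weighted where
  fmap f (Weighted xs) = Weighted [ (f x, w) | (x, w) <- xs ]

instance Applicative Weighted where
  pure x = Weighted [(x, 1)]
  Weighted fs <*> Weighted xs = Weighted [ (f x, v * w) | (f, v) <- fs, (x, w) <- xs ]

instance Monad Weighted where
  Weighted xs >>= k = Weighted [ (y, w * v) | (x, w) <- xs, (y, v) <- runWeighted (k x) ]

-- | Multiply the current weight by a factor.
score :: Double -> Weighted ()
score w = Weighted [((), w)]

-- | Hypothetical observe: score with the probability the distribution assigns to b.
--   This works only because Weighted can enumerate its outcomes.
observe :: Eq b => Weighted b -> b -> Weighted ()
observe dist b = score (sum [ w | (x, w) <- runWeighted dist, x == b ])

-- | Noisy sensor: reads the true state, or is off by one in either direction.
sensor :: Int -> Weighted Int
sensor s = Weighted [(s - 1, 0.25), (s, 0.5), (s + 1, 0.25)]

-- | Normalized posterior over a uniform latent state in {0, 1, 2}, given one reading.
posterior :: Int -> [(Int, Double)]
posterior reading = [ (s, w / total) | (s, w) <- ws ]
  where
    ws = runWeighted $ do
      s <- Weighted [ (s0, 1 / 3) | s0 <- [0, 1, 2] ]
      observe (sensor s) reading
      return s
    total = sum (map snd ws)
```

With a reading of 1, the posterior is 0.25/0.5/0.25; with a reading of 3, only state 2 remains possible. Conditioning on equality with a continuous sensor would instead reject almost every sample.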

And if so, can we write a particle filter that generalizes across Moore and Mealy machines (in the spirit of the machines library, perhaps)?

I'm not sure how that would work. Anyway, maybe we can move the machines discussion elsewhere? Possibly to https://github.com/turion/rhine/discussions?

reubenharry commented 9 months ago

Thanks, this clarifies in particular the relevant difference between Moore and Mealy, which I was a bit unclear on.

So, to summarize, is the idea that we should use a Moore machine (roughly, but more specifically a resampling buffer) as our filter, since it is naturally asynchronous? That makes a lot of sense, if so. I guess the question I have then is just how the details look. For example, when does the resampling step happen? At every get? Or every put? Or either/neither?

I agree that the issue with monad-bayes not having observe is annoying.

In an ideal world, the type I'd want would be, e.g.:

asyncPF :: (MonadMeasure n, MonadDistribution m) => Behavior n cl a b -> ResamplingBuffer m cl1 cl2 b [(a, Log Double)]

or more concretely:

asyncPF :: (MonadDistribution m) => Behavior (PopulationT m) cl a b -> ResamplingBuffer m cl1 cl2 b [(a, Log Double)]

by analogy to the existing particle filter.

And yes, re. machines, that's a separate discussion entirely, I agree.

Also for another discussion, but I think relevant: do you have a version of the particle filter for the initial encoding of MealyT?

turion commented 9 months ago

So, to summarize, is the idea that we should use a Moore machine (roughly, but more specifically a resampling buffer) as our filter, since it is naturally asynchronous? That makes a lot of sense, if so.

Yes.

I guess the question I have then is just how the details look. For example, when does the resampling step happen? At every get? Or every put? Or either/neither?

On every put, I think, because this is the only situation where the probabilities are updated. One could use something like the effective sample size (https://github.com/tweag/monad-bayes/pull/268) to prevent it from being called too often.
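A sketch of resampling on put only when the population has degenerated according to the effective sample size. The threshold, the function names, and the choice of systematic resampling (with the uniform draw u passed in explicitly so the code stays pure) are assumptions for illustration, not taken from monad-bayes or the linked PR.

```haskell
-- | Effective sample size of a weighted population: (sum w)^2 / sum (w^2).
--   It equals n for uniform weights and approaches 1 when one particle dominates.
effectiveSampleSize :: [(a, Double)] -> Double
effectiveSampleSize ps = total * total / sum [ w * w | (_, w) <- ps ]
  where total = sum (map snd ps)

-- | Systematic resampling with one uniform draw u in [0, 1):
--   n evenly spaced pointers select particles in proportion to their weights.
systematicResample :: Double -> [(a, Double)] -> [(a, Double)]
systematicResample u ps = go pointers cumulative
  where
    n = length ps
    total = sum (map snd ps)
    pointers = [ (fromIntegral i + u) / fromIntegral n | i <- [0 .. n - 1] ]
    cumulative = zip (map fst ps) (drop 1 (scanl (+) 0 [ w / total | (_, w) <- ps ]))
    go [] _ = []
    go _ [] = []
    go (p : rest) cs@((x, c) : cs')
      -- The last particle absorbs any floating-point shortfall in the cumulative sum.
      | p < c || null cs' = (x, 1 / fromIntegral n) : go rest cs
      | otherwise = go (p : rest) cs'

-- | Resample after a put only when the ESS falls below threshold * n.
resampleIfDegenerate :: Double -> Double -> [(a, Double)] -> [(a, Double)]
resampleIfDegenerate threshold u ps
  | effectiveSampleSize ps < threshold * fromIntegral (length ps) = systematicResample u ps
  | otherwise = ps
```

For the weights 0.7/0.1/0.1/0.1, the ESS is 1/0.52 (about 1.92), below 0.5 * 4 = 2, so the population is resampled (with u = 0.5 this gives a, a, a, c with equal weights). A population with uniform weights has an ESS equal to its size and is left alone.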

In an ideal world, the type I'd want would be, e.g.:

asyncPF :: (MonadMeasure n, MonadDistribution m) => Behavior n cl a b -> ResamplingBuffer m cl1 cl2 b [(a, Log Double)]

or more concretely:

asyncPF :: (MonadDistribution m) => Behavior (PopulationT m) cl a b -> ResamplingBuffer m cl1 cl2 b [(a, Log Double)]

by analogy to the existing particle filter.

That's a good point. That way, you're pushing the task of conditioning onto the user. That's OK, since it is more general, and we can supply a helper function that takes the bare process (without input) and a likelihood function, and combines them.

Also for another discussion, but I think relevant: do you have a version of the particle filter for the initial encoding of MealyT?

Yes: https://github.com/turion/rhine/blob/9f11e1c6027983a51102dd40b0f711dab00a5138/rhine-bayes/src/Data/MonadicStreamFunction/Bayes.hs#L38

I find it much more readable than the final encoding.

turion commented 9 months ago
asyncPF :: (MonadDistribution m) => Behavior (PopulationT m) cl a b -> ResamplingBuffer m cl1 cl2 b [(a, Log Double)]

I didn't read the type signature carefully. Now that I have, I see two reasons why it won't work this straightforwardly:

  1. Actually, this type signature isn't analogous to the existing particle filter. Here, you are swapping input and output (as you might expect in a Bayesian setting), but the existing particle filter doesn't do that. Since it offloads the task of conditioning to the user, it doesn't need to swap input and output.
  2. So, given my comment in 1., you might expect the type signature to be asyncPF :: (MonadDistribution m) => Behavior (PopulationT m) cl a b -> ResamplingBuffer m cl1 cl2 a [(b, Log Double)] (with a & b swapped) instead. But this doesn't work either: the behaviour expects an a on every step, whereas the resampling buffer only receives one on put, not on get.

We can work around 2. if the behaviour can deal with absent values: asyncPF :: (MonadDistribution m) => Behavior (PopulationT m) cl (Maybe a) b -> ResamplingBuffer m cl1 cl2 a [(b, Log Double)]. But then the user has to take care to condition only when the value is present. When we offer an API where prior and likelihood are separated, the framework can handle this. Maybe it's best to supply both options.
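The helper idea, where prior and likelihood are supplied separately and the framework decides when to condition, might look like this. It is a hypothetical sketch, generic in the monad; to keep it self-contained, the usage example relies on base's writer-like ((,) w) monad with Product Double as a stand-in for a weighted population, instead of PopulationT.

```haskell
import Data.Monoid (Product (..))

-- | Hypothetical helper: build a step that accepts a Maybe observation from a
--   prior transition and a likelihood. The prior always runs; conditioning
--   only happens when a value is present.
conditionWhenPresent
  :: Monad m
  => (s -> m s)        -- ^ one step of the prior simulation
  -> (s -> b -> m ())  -- ^ likelihood: score the new state against an observation
  -> Maybe b -> s -> m s
conditionWhenPresent prior likelihood observation s = do
  s' <- prior s
  mapM_ (likelihood s') observation
  return s'

-- | Stand-in weighting monad: the first component multiplies like a particle weight.
type Weighted = (,) (Product Double)

countUp :: Int -> Weighted Int
countUp s = (Product 1, s + 1)

sensor :: Int -> Int -> Weighted ()
sensor s reading = (Product (if s == reading then 0.9 else 0.1), ())
```

With stepWith = conditionWhenPresent countUp sensor, stepWith Nothing 0 only advances the prior, giving (Product 1.0, 1), whereas stepWith (Just 1) 0 also scores the matching reading, giving (Product 0.9, 1).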

reubenharry commented 9 months ago

Yeah sorry, the switch of a and b was a typo. I meant asyncPF :: (MonadDistribution m) => Behavior (PopulationT m) cl a b -> ResamplingBuffer m cl1 cl2 a [(b, Log Double)]. But re your point 2, what about:

asyncPF :: (MonadDistribution m) => ResamplingBuffer (PopulationT m) cl1 cl2 a b -> ResamplingBuffer m cl1 cl2 a [(b, Log Double)]

in direct analogy to the existing pf.

turion commented 9 months ago

I guess this can be implemented, but what does it mean? I'm a bit puzzled about how we would construct a ResamplingBuffer (PopulationT m) cl1 cl2 a b. It has to condition on the input (and step the prior simulation), and it has to only step the simulation on output. I think it's a bit weird to demand that the user write this. Nevertheless, it looks like a valuable intermediate concept.

turion commented 9 months ago

Also, the time stamps for resampling buffers can be confusing, since they run separately for put and get. I'd rather take this complexity away from the user as long as there is no need to expose it.

reubenharry commented 9 months ago

I guess this can be implemented, but what does it mean? I'm a bit puzzled how we construct a ResamplingBuffer (PopulationT m) cl1 cl2 a b. It has to condition on the input (and step the prior simulation), it has to only step the simulation on output. I think it's a bit weird to demand the user writing this. Nevertheless, it looks like a valuable intermediate concept.

Yeah, I mean it as an intermediate concept, in the sense that the machinery of the particle filter gets implemented by asyncPF, and then further tools are provided to help the user construct their unnormalized asynchronous stochastic process of type MonadMeasure m => ResamplingBuffer m cl1 cl2 a b. But maybe this doesn't make sense; I'll play with your code and get a more concrete sense of how this all works.