tweag / monad-bayes

A library for probabilistic programming in Haskell.
MIT License
409 stars 62 forks source link

*** Exception: System.Random.MWC.Distributions.categorical: bad weights! #27

Closed idontgetoutmuch closed 7 years ago

idontgetoutmuch commented 7 years ago
{-# OPTIONS_GHC -Wall     #-}

{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE BangPatterns #-}

module Main where

import Control.Monad.Bayes.LogDomain
import Control.Monad.Bayes.Primitive
import Control.Monad.Bayes.Class
import Control.Monad.Bayes.Population
import Control.Monad.Bayes.Conditional
import Control.Monad.Bayes.Inference

import Numeric.GSL.ODE
import Numeric.LinearAlgebra hiding ( step, Vector )

import qualified Data.Vector as V
import Data.Vector ( Vector )
import Control.Monad

import Control.Monad.Bayes.Sampler

import Debug.Trace

-- | Particle count handed to 'smc' for every pseudo-marginal density
-- estimate computed during PMMH.
n_particles :: Int
n_particles = 5

-- model data and constants

-- | Length of one time step: the ODE is solved over @[0, h]@ and the
-- per-step noise is drawn with standard deviation @sqrt h@.
h :: Double
h = 0.1

-- | Fixed coefficients of the hare/lynx dynamics.
k1', b', d', k2', c' :: Double
k1' = 200    -- Hare carrying capacity
b'  = 0.02   -- Hare death rate per lynx
d'  = 0.4    -- Lynx death rate
k2' = 20     -- Lynx carrying capacity
c'  = 0.004  -- Lynx birth rate per hare

-- a :: Double
-- a = 0.5

-- | State threaded through the simulation, one value per time step.
data S = S
  { p         :: Double  -- ^ first ODE variable; its log is the mean of the observation model
  , z         :: Double  -- ^ second ODE variable
  , log_alpha :: Double  -- ^ log of the growth coefficient; exponentiated before use
  }

-- | Prior over the model parameters: @mu@ is uniform on (0, 1) and @sigma@
-- is uniform on (0, 5), drawn in that order and returned as @(mu, sigma)@.
parameters :: (MonadDist m, CustomReal m ~ Double) => m (Double, Double)
parameters = (,) <$> uniform 0 1 <*> uniform 0 5

-- | Samples the state at time zero from the parameters @(mu, sigma)@.
--
-- Both populations are log-normal (centred on 100 and 50 on the log scale);
-- the log growth coefficient is normal around @log mu@ with deviation @sigma@.
initial_state :: (MonadBayes m, CustomReal m ~ Double) => (Double, Double) -> m S
initial_state (mu, sigma) = do
  logP     <- normal (log 100) 0.2
  logZ     <- normal (log  50) 0.1
  logAlpha <- normal (log mu) sigma
  return S { p = exp logP, z = exp logZ, log_alpha = logAlpha }

-- | Transition model: advances the state by one time step of length 'h'.
--
-- The populations are moved forward by solving the ODE with 'odeSolve',
-- using @exp (log_alpha state)@ as the growth coefficient.  The new
-- @log_alpha@ is built from a fresh normal draw and @sigma@ (the second
-- component of @params@); the first component of @params@ is not used here.
transition :: (MonadBayes m, CustomReal m ~ Double) => (Double, Double) -> S -> m S
transition params state = do
  noise <- normal 0 (sqrt h)
  let sigma    = snd params
      -- The solver is asked for the times [0, h]; row 1 is taken as the
      -- state at time h.
      endState = odeSolve rhs [p state, z state] (fromList [0.0, h]) ! 1
  return S
    { p         = endState ! 0
    , z         = endState ! 1
    , log_alpha = -sigma * sigma * h / 2 - sigma * noise
    }
  where
    growth = exp (log_alpha state)

    -- Right-hand side of the two-variable ODE; any other arity is a
    -- programming error.
    rhs _t vars = case vars of
      [pp, zz] ->
        [ growth * pp * (1 - pp / k1') - b' * pp * zz
        , -d' * zz * (1 + zz / k2') + c' * pp * zz
        ]
      _ -> error $ "ppOde called with: " ++ show (length vars) ++ " variables"

-- | Full simulation for fixed parameters; returns the parameters unchanged.
--
-- Starts from 'initial_state' and, for every entry of 'obs', samples a
-- 'transition' and then scores the observation against the state the step
-- started from.  The final state is discarded.
model :: (MonadBayes m, CustomReal m ~ Double) => (Double, Double) -> m (Double, Double)
model params = do
  start <- initial_state params
  -- A monadic fold threads the state directly instead of folding over
  -- unevaluated actions, which would nest the binds to the left.
  foldM_ step start obs
  return params
  where
    step state p_obs = do
      new_state <- transition params state
      observe (Continuous (Normal (log (p state)) 0.1)) (log p_obs)
      return new_state

-- | Complete model: draws parameters from the prior, runs 'model' under a
-- particle filter with one resampling point per observation and
-- 'n_particles' particles, and collapses the population to the parameters.
full_model :: (MonadBayes m, CustomReal m ~ Double) => m (Double, Double)
full_model = do
  params <- parameters
  collapse (smc (length obs) n_particles (model params))

-- | PMMH proposal kernel: perturbs each of the two parameters with
-- unit-variance normal noise, in the order @[mu, sigma]@.
--
-- monad-bayes does not currently have truncated normals, so proposals are
-- not constrained to the support of the prior.
-- Calling it with anything other than a two-element list is an error.
pmmh_kernel :: (MonadDist m, CustomReal m ~ Double) => [Double] -> m [Double]
pmmh_kernel current = case current of
  [mu, sigma] -> mapM (`normal` 1) [mu, sigma]
  xs -> error $ "pmmh_kernel called with: " ++ show (length xs) ++ " variables"

-- | One full PMMH transition: propose with 'pmmh_kernel', estimate the
-- pseudo-marginal density of both the current and the proposed parameters,
-- and accept or reject with the Metropolis-Hastings ratio.
--
-- Returns the proposal when accepted and the unchanged input otherwise.
-- The current and proposed parameters are traced to stderr for debugging.
pmmh_step :: (MonadDist m, CustomReal m ~ Double) => [Double] -> m [Double]
pmmh_step params = do
  params' <- trace ("Params: " ++ show params) $ pmmh_kernel params
  let kernel_density  = unsafeContJointDensity (pmmh_kernel params) params'
      kernel_density' = unsafeContJointDensity (pmmh_kernel params') params
  -- NOTE(review): the density of the current parameters is re-estimated on
  -- every step rather than carried over from the step that accepted them.
  pm_density  <- pseudoDensity full_model (map Just params , [])
  pm_density' <- pseudoDensity full_model (map Just params', [])
  let mh_ratio = fromLogDomain (pm_density' * kernel_density' / (pm_density * kernel_density))
      -- A zero-over-zero ratio is NaN, and 'min 1 NaN' is NaN as well, so
      -- map it to a certain rejection instead of handing NaN to 'bernoulli'.
      accept_prob
        | isNaN mh_ratio = 0
        | otherwise      = min 1 mh_ratio
  accept <- bernoulli accept_prob
  trace ("Accept: " ++ show accept) $ return (if accept then params' else params)

-- | Applies a monadic step @n@ times, collecting the intermediate results.
--
-- For @n >= 1@ the vector holds @f x@, @f (f x)@, ... followed by one extra
-- copy of the last value, because the base case wraps the seed it receives;
-- the result therefore has @n + 1@ elements.  For @n <= 0@ the result is the
-- singleton @x@ (a negative count no longer recurses forever).
iterateNM :: Monad m => Int -> (a -> m a) -> a -> m (Vector a)
iterateNM n f x
  | n <= 0 = return $ V.singleton x
  | otherwise = do
      y <- f x
      V.cons y <$> iterateNM (n - 1) f y

-- | Draws initial parameters from the prior, runs two PMMH steps with a
-- fixed random seed, and prints one number per parameter.
main :: IO ()
main = do
  ps1 <- sampleIOfixed $ do
    start <- parameters
    iterateNM 2 pmmh_step [fst start, snd start]
  let kept     = V.drop 1 ps1           -- skip the first recorded sample
      column i = V.map (!! i) kept      -- i-th parameter of every kept sample
  -- The divisor is hard-coded to 1, so these print plain sums.
  print (sum (column 0) / 1)
  print (sum (column 1) / 1)

-- | Observed data.  The observation model in 'model' compares the log of
-- each entry with the log of the state's @p@ component.
-- Only the first three entries of the literal below are used ('take 3').
obs :: [Double]
obs = take 3 $ [
  76.2403679754159,
  65.0098784873532,
  72.7307834964011,
  66.9032651694089,
  69.4511465390588,
  64.2803362618095,
  56.2848299444718,
  55.38783776159,
  58.4038970077208,
  65.1111562800144,
  61.8487358181798,
  54.6268190304816,
  61.7832255452353,
  57.2580992086863,
  51.6506692415061,
  53.513024643873,
  53.2516281763045,
  52.8808601383571,
  50.8172783401389,
  56.4015933848598,
  51.9694880143296,
  53.6405052882993,
  54.575606130029,
  59.3492395433371,
  60.0888983492202,
  44.4110818026195,
  54.4982828776438,
  73.4084712052465,
  54.7644338856397,
  65.9215790688941,
  64.4989883914755,
  80.9772033868803,
  74.4092779374009,
  69.1141635364459,
  67.4281988428452,
  73.2048837536519,
  68.120224957282,
  62.0069646761111,
  80.1892530691497,
  61.1742005270423,
  79.8803349040011,
  83.2200370887438,
  77.5918040799712,
  72.5731170839407,
  77.4547803718811,
  76.6217772509249,
  73.676388095968,
  85.6784703161388,
  104.170224855567,
  89.2220461329041,
  97.8557959956232,
  108.537541058624,
  80.7233691710187,
  101.184198386732,
  106.119103176147,
  100.475879494068,
  125.13032727884,
  91.0701758431548,
  88.8812050725806,
  100.679238971492,
  110.589530407881,
  89.6049957679831,
  103.926588273376,
  100.068112237358,
  115.512346913051,
  113.109941769312,
  114.117083726776,
  126.223426393656,
  116.655177064036,
  119.47938528113,
  119.673062865914,
  122.17342562363,
  115.260394834536,
  118.68195366763,
  124.943355707336,
  151.936975691988,
  114.130007656096,
  148.946645952698,
  144.565611912741,
  118.759397178281,
  126.673416199079,
  142.600114603459,
  124.392431912102,
  123.845615416597,
  131.162727200371,
  127.530085947053,
  149.07114126,
  123.535118463953,
  166.415187417087,
  131.719924167561,
  139.991005535832,
  130.070166508874,
  141.353801701055,
  136.634266000686,
  120.874938386323,
  135.843010192177,
  158.44951183627,
  150.728722863145,
  176.881378505555,
  133.774741754552,
  137.445456445669,
  124.34236836753,
  152.292445573619,
  145.969440861177,
  131.636047513673,
  146.963132155529,
  151.695165885035,
  165.977302905119,
  142.031450539543,
  164.846194851848,
  168.013420195697,
  152.767971695764,
  143.851647250068,
  160.436734291349,
  152.302294176293,
  155.067612585515,
  175.247982500084,
  136.897735270247,
  151.97211044291,
  160.887803973557,
  136.486833382118,
  158.883061549789,
  152.458944513575,
  149.023829176969,
  155.340944696967,
  157.829634728021,
  158.633421934879,
  132.036887061947,
  135.101844477751,
  153.29922033573,
  158.052850728581,
  143.369824868863,
  157.714843025493,
  144.325777788735,
  131.0549741648,
  144.337275322618,
  145.026723249829,
  160.030142868999,
  159.542825278056,
  147.87991947625,
  126.381489671728,
  145.226447360054,
  136.894318063363,
  159.268060107793,
  119.239998434002,
  132.788094593534,
  150.663078585921,
  144.430826746593,
  165.173155950077,
  147.929951578988,
  157.760394555343,
  131.053973697495,
  165.580489628044,
  146.718090811465,
  142.773179403759,
  146.392167581097,
  146.025983883888,
  141.47610191321,
  134.214216068753,
  119.248789990169,
  135.031489547166,
  120.561743419339,
  167.557212051444,
  148.168532426848,
  136.547249833344,
  152.416629324237,
  173.5981255812,
  132.895016695342,
  138.669418114316,
  167.112761599289,
  119.007943156896,
  144.496358594722,
  138.515189246643,
  158.554459930596,
  136.975247626486,
  129.513055665373,
  169.020473191163,
  156.116991826441,
  129.524981775373,
  148.85363969582,
  157.801922552141,
  157.577368682676,
  140.430975061626,
  136.205715759791,
  158.040217033728,
  125.493638253065,
  130.114292183725,
  142.750827896278,
  174.071675002185,
  155.168991141177,
  165.5803346601,
  133.967097475562,
  133.930047497004,
  127.601345216493,
  135.602745924639,
  163.280148579137,
  150.634881342163,
  144.898580660782,
  123.221738509521,
  170.394076047434,
  152.596479493751,
  156.695148025635,
  125.317613954204,
  152.957730717693,
  168.012683439549,
  120.716643641628,
  150.531587838906,
  145.104736014324,
  124.182149449396,
  133.269771329542,
  114.976356844067,
  139.860395302311,
  137.801555771354,
  139.678968034677,
  134.434880549483,
  138.011761267802,
  125.439482936179,
  151.74813237231,
  131.847695399538,
  123.319231908423,
  152.547130543284,
  125.491719298048,
  137.096981763213,
  138.768540718737,
  181.5283483391,
  126.098802734704,
  163.482620601302,
  128.342617368436,
  141.586517467223,
  119.474015267941,
  120.9364132138,
  112.984981940618,
  141.341437153727,
  135.160684269048,
  134.509200686011,
  163.319201985027,
  135.415854723738,
  159.063490681458,
  153.148663733748,
  149.026302196215,
  164.826019093328,
  123.071508271513,
  176.387627783159,
  149.483954577518,
  123.02095153944,
  156.650270662224,
  181.195533023262,
  141.238386891745,
  134.701979169655,
  131.340667101923,
  144.548953780858,
  131.588662675565,
  144.541170469088,
  125.578498701588,
  164.74439785325,
  127.766537075735,
  133.273837701266,
  145.93004226061,
  123.374900614912,
  151.596803755421,
  144.894219796883,
  126.998541138479,
  147.035896225008,
  136.476333130728,
  135.608414577445,
  131.807584253206,
  127.366264864742,
  134.964091100405,
  143.421168142046,
  142.054495699452,
  151.438892001945,
  143.05917211411,
  136.712518442789,
  120.003089311654,
  140.869561010692,
  135.058679734824,
  134.081730533486,
  139.43057733862,
  162.26933796043,
  136.471725650913,
  125.873826665898,
  151.097217418369,
  121.162250703819,
  150.767682408018,
  130.606266701801,
  105.20437547654,
  132.359328972528,
  116.989153980971,
  148.954925443656,
  142.105725642533,
  143.553119435633,
  127.751657026012,
  147.148424004414,
  114.85477950561,
  139.694548684757,
  146.171259764561,
  127.789841217549,
  163.168916396005,
  147.231424950466,
  125.058830151811,
  119.074143576494,
  133.536702795857,
  137.262453768502,
  179.672247639064,
  118.207552485961,
  136.618926147412,
  142.352972731303,
  136.479602357591,
  118.750751452668,
  147.986276828413,
  147.411481380629,
  144.052405584174,
  146.153708105283,
  131.557481669184,
  134.453090884824,
  153.443307349816,
  146.984446479185,
  138.421394583867,
  132.707447826652,
  136.516554357634,
  122.191521371155,
  134.80783216139,
  155.447901778778,
  136.926153057977,
  125.184160305696,
  122.367097582219,
  134.726963096563,
  138.72230909606,
  133.531816345683,
  138.817134160604,
  151.680755115041,
  143.081423812855,
  144.507716078662,
  150.54905673819,
  164.395431529991,
  136.035307291772,
  135.963680724784,
  159.088071380666,
  115.544651209282,
  126.449866575417,
  145.861556307966,
  121.299809587301,
  135.041076468273,
  115.420027815205,
  120.649592931961,
  149.222665704902,
  136.787170312266,
  117.577955349667,
  123.09131466115,
  149.613666057669,
  127.542230989168,
  148.826516120672,
  172.024769211347,
  119.163511852413,
  168.846806388603,
  122.577905606408,
  163.691820942384,
  121.152394606119,
  125.364015734805,
  134.053429227603,
  116.102707413675,
  115.932034184824,
  157.856002757677,
  120.417539311501,
  159.448081983552,
  146.523081800403,
  142.244157886188,
  147.49875516021,
  149.794865803213,
  142.424753997512,
  155.377391656754,
  133.457569330548,
  157.979204500375,
  132.867976837153,
  150.766104208022,
  151.574164752913,
  142.79494077867,
  135.300935952747,
  155.087644429195,
  134.339799297657,
  117.780905975319,
  141.325010750032,
  137.141299531074,
  150.973508888167,
  135.037762357837,
  145.031582826828,
  138.551290507779,
  144.969403511833,
  162.819743874685,
  132.441267184392,
  131.380740568251,
  141.527642238539,
  124.238855213983,
  155.588321253241,
  145.068004565977,
  120.849885116319,
  138.600747823775,
  138.767637532808,
  149.658261434865,
  153.218173715505,
  123.173078726806,
  144.980628125895,
  133.601607881188,
  143.600386923789,
  134.99375013592,
  178.108618865062,
  162.589164120014,
  149.801157966121,
  124.574389244282,
  159.749711094236,
  154.462697945675,
  124.745887246537,
  134.619145288027,
  132.829029331846,
  133.515016847945,
  130.370138175619,
  134.058600191345,
  138.967206000683,
  144.844845114948,
  140.150588785047,
  168.155026956341,
  105.663614240897,
  188.849879696026,
  134.066594679321,
  118.875307519085,
  152.36326566797,
  126.208445705652,
  168.297022523569,
  152.27258877451,
  140.382228072818,
  119.424396673087,
  161.465352774379,
  151.655127807313,
  135.515116635163,
  151.532357393466,
  151.810511351233,
  136.361410128331,
  153.741012514592,
  147.431121676277,
  157.91915964734,
  120.327187813985,
  120.171500909668,
  142.428896431138,
  170.204819149259,
  158.214506571837,
  138.733887337899,
  123.808097192822,
  152.903612727097,
  140.769759838882,
  119.667303148295,
  129.45932949513,
  149.16660912873,
  150.254944400129,
  128.273570440445,
  153.278564595965,
  121.864809541511,
  161.542701435884,
  156.946807512148,
  158.881028441797,
  125.569197718704,
  137.940620707105,
  129.90603859364,
  152.757290961669,
  127.74685675702,
  137.183947553826,
  150.462794723172,
  140.929843406962,
  164.975435125998,
  182.344447097492,
  156.919848441299,
  150.750704137777,
  141.693725748231,
  110.676915031473,
  173.230786634625,
  148.248841760484,
  119.716751121111,
  143.947798672895,
  130.440507876832,
  131.964077776465,
  134.646898631033,
  100.778877180885,
  124.348744604234,
  127.180330688462,
  158.146060814462,
  171.358807156152,
  131.864294544909,
  143.085939246005,
  149.586050733912,
  132.2767374166,
  129.632802049237,
  137.02205312678,
  146.243848392623,
  125.713745961482,
  132.545872818923,
  129.355428299295,
  150.515004650034,
  149.481788573074,
  154.478850748833,
  137.329325663629,
  127.9504469047,
  133.350229974433]
idontgetoutmuch commented 7 years ago

I have tried putting in trace statements, but this version of `proper`

-- | A properly weighted single sample, that is one picked at random according
-- to the weights, with an estimator of the model evidence.
proper :: MonadDist m => Population m a -> m (a,LogDomain (CustomReal m))
proper m = do
  pop <- runPopulation m
  -- Split the population into the sampled values and their weights.
  let (xs, ps) = unzip pop
  -- Total weight; this is the second component of the result.
  let z = sum ps
  -- Debugging trace: 'show ps' needs a Show instance for the weights, which
  -- the 'MonadDist m' context in the signature does not supply.
  index <- trace ("Population: " ++ show ps) $ discrete $ map fromLogDomain ps
  let x = xs !! index
  return (x,z)

gives

../../monad-bayes/src/Control/Monad/Bayes/Population.hs:130:37: error:
    • Could not deduce (Show (CustomReal m))
        arising from a use of ‘show’
      from the context: MonadDist m
        bound by the type signature for:
                   proper :: MonadDist m =>
                             Population m a -> m (a, LogDomain (CustomReal m))
        at ../../monad-bayes/src/Control/Monad/Bayes/Population.hs:125:1-73
    • In the second argument of ‘(++)’, namely ‘show ps’
      In the first argument of ‘trace’, namely
        ‘("Population: " ++ show ps)’
      In the expression: trace ("Population: " ++ show ps)

It's not clear to me what changes I would need to make to get CustomReal m to be showable.

It turns out this works

  index <- trace ("Population: " ++ show (map toRational ps)) $ discrete $ map fromLogDomain ps

Some more trace statements later

    -- Picks an index with 'discrete' using the weights (the second component
    -- of each pair) and returns the value stored at that index.
    categorical :: [(a,CustomReal m)] -> m a
    categorical d = do
      -- Debugging trace of the weights; they are converted with 'toRational'
      -- and then 'fr' (defined outside this excerpt) before being shown.
      i <- trace ("Categorical: " ++ show (map (fr . toRational . snd) d)) $ discrete (map snd d)
      return (fst (d !! i))
Categorical: [Infinity,Infinity,Infinity,Infinity,Infinity]

So this is causing the NaNs, but what is causing the Infinities?

adscib commented 7 years ago

I've added some code to improve numerical stability, but I was unable to run your example - the transition model seems to be taking forever to execute.

idontgetoutmuch commented 7 years ago

I haven't tried your changes yet but I did this

-- | Binding to the C function @feenableexcept@; it is called with the mask
-- 'allFloatExceptions' at the start of 'main'.
-- NOTE(review): the C prototype presumably uses @int@, for which
-- @Foreign.C.Types.CInt@ would be the exact match -- confirm before relying
-- on the returned value.
foreign import ccall "feenableexcept" enableFloatException :: Int -> IO Int

-- | Mask passed to 'enableFloatException'.  Despite the name, only the
-- DIVBYZERO entry is currently switched on; the others are commented out.
allFloatExceptions :: Int
allFloatExceptions = sum enabled
  where
    enabled =
      [ -- 1  {-INVALID-}
        4     {-DIVBYZERO-}
        -- 8  {-OVERFLOW-}
        -- 16 {-UNDERFLOW-}
      ]

-- | Enables the floating-point exceptions selected by 'allFloatExceptions',
-- then runs two PMMH steps with a fixed random seed and prints one number
-- per parameter.
main :: IO ()
main = do
  _ <- enableFloatException allFloatExceptions
  ps1 <- sampleIOfixed $ do
    start <- parameters
    iterateNM 2 pmmh_step [fst start, snd start]
  let kept     = V.drop 1 ps1           -- skip the first recorded sample
      column i = V.map (!! i) kept      -- i-th parameter of every kept sample
  -- The divisor is hard-coded to 1, so these print plain sums.
  print (sum (column 0) / 1)
  print (sum (column 1) / 1)

and got

~/Dropbox/Private/Stochastic/demo $ ./app/Main 
Params: [2.481036288296201e-2,3.704320339726504]
Resample: [5.073988065454816e-2,8.813732722277414e-4,1.3167705703870092e-2,0.7900797866431766,0.7945813585078162]
Categorical: [5.073988065454816e-2,8.813732722277414e-4,1.3167705703870092e-2,0.7900797866431766,0.7945813585078162]
[5.073988065454816e-2,8.813732722277414e-4,1.3167705703870092e-2,0.7900797866431766,0.7945813585078162]
[5.073988065454816e-2,8.813732722277414e-4,1.3167705703870092e-2,0.7900797866431766,0.7945813585078162]
[5.073988065454816e-2,8.813732722277414e-4,1.3167705703870092e-2,0.7900797866431766,0.7945813585078162]
[5.073988065454816e-2,8.813732722277414e-4,1.3167705703870092e-2,0.7900797866431766,0.7945813585078162]
[5.073988065454816e-2,8.813732722277414e-4,1.3167705703870092e-2,0.7900797866431766,0.7945813585078162]
Floating point exception: 8

So quite early on we are seeing a divide by zero. I am not convinced this is a numerical stability problem. It could be that we are asking the solver to do something daft. I am going to try your changes. If that doesn't work, I will put in a simple Euler solver rather than use GSL. I don't believe my model is stiff, so this should make little difference, and we will have gotten rid of one more dependency.

idontgetoutmuch commented 7 years ago

This uses a simple Euler scheme to update the state so should (ahem) not take forever to execute. I still get a floating point exception even using the latest commit (divide by 0).

{-# OPTIONS_GHC -Wall     #-}

{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE BangPatterns #-}

module Main where

import Control.Monad.Bayes.LogDomain
import Control.Monad.Bayes.Primitive
import Control.Monad.Bayes.Class
import Control.Monad.Bayes.Population
import Control.Monad.Bayes.Conditional
import Control.Monad.Bayes.Inference

-- import Numeric.GSL.ODE
import Numeric.LinearAlgebra hiding ( step, Vector )
import qualified Numeric.LinearAlgebra as SV

import qualified Data.Vector as V
import Data.Vector ( Vector )
import Control.Monad

import Control.Monad.Bayes.Sampler

import Debug.Trace

-- | Particle count handed to 'smc' for every pseudo-marginal density
-- estimate computed during PMMH.
n_particles :: Int
n_particles = 5

-- model data and constants

-- | Length of one time step: the per-step noise is drawn with standard
-- deviation @sqrt h@, and @h@ also scales the drift of @log_alpha@.
h :: Double
h = 0.1

-- | Fixed coefficients of the hare/lynx dynamics.
k1', b', d', k2', c' :: Double
k1' = 200    -- Hare carrying capacity
b'  = 0.02   -- Hare death rate per lynx
d'  = 0.4    -- Lynx death rate
k2' = 20     -- Lynx carrying capacity
c'  = 0.004  -- Lynx birth rate per hare

-- a :: Double
-- a = 0.5

-- | State threaded through the simulation, one value per time step.
data S = S
  { p         :: Double  -- ^ first population; its log is the mean of the observation model
  , z         :: Double  -- ^ second population
  , log_alpha :: Double  -- ^ log of the growth coefficient; exponentiated before use
  }
  deriving Show

-- | Prior over the model parameters: @mu@ is uniform on (0, 1) and @sigma@
-- is uniform on (0, 5), drawn in that order and returned as @(mu, sigma)@.
parameters :: (MonadDist m, CustomReal m ~ Double) => m (Double, Double)
parameters = (,) <$> uniform 0 1 <*> uniform 0 5

-- | Samples the state at time zero from the parameters @(mu, sigma)@.
--
-- Both populations are log-normal (centred on 100 and 50 on the log scale);
-- the log growth coefficient is normal around @log mu@ with deviation @sigma@.
initial_state :: (MonadBayes m, CustomReal m ~ Double) => (Double, Double) -> m S
initial_state (mu, sigma) = do
  logP     <- normal (log 100) 0.2
  logZ     <- normal (log  50) 0.1
  logAlpha <- normal (log mu) sigma
  return S { p = exp logP, z = exp logZ, log_alpha = logAlpha }

-- | Transition model: advances the state by one time step.
--
-- The populations are updated with 'euler', using @exp (log_alpha state)@ as
-- the growth coefficient (this replaces the earlier @odeSolve@-based update).
-- The new @log_alpha@ is built from a fresh normal draw and @sigma@ (the
-- second component of @params@).  The parameters and incoming state are
-- traced to stderr for debugging.
transition :: (MonadBayes m, CustomReal m ~ Double) => (Double, Double) -> S -> m S
transition params state = trace (show params ++ " " ++ show state) $ do
  noise <- normal 0 (sqrt h)
  let sigma = snd params
      next  = euler (exp (log_alpha state)) (p state) (z state)
  return S
    { p         = next ! 0
    , z         = next ! 1
    , log_alpha = -sigma * sigma * h / 2 - sigma * noise
    }

-- | One explicit (forward) Euler step of length 'h' for the hare/lynx ODE.
--
-- Takes the growth coefficient @a@ and the current values of the two
-- populations, and returns the two-element vector @[newP, newZ]@.
euler :: Double -> Double -> Double -> SV.Vector Double
euler a prevP prevZ = fromList [newP, newZ]
  where
    -- Time derivatives of the two populations at the current state.
    dP = a * prevP * (1 - prevP / k1') - b' * prevP * prevZ
    dZ = -d' * prevZ * (1 + prevZ / k2') + c' * prevP * prevZ
    -- Euler update: previous value plus step size times derivative.  Both
    -- the previous value and the factor 'h' are needed in each component;
    -- without them the step is not an approximation of the ODE flow over
    -- an interval of length 'h' and can drive the populations negative.
    newP = prevP + h * dP
    newZ = prevZ + h * dZ

-- | Full simulation for fixed parameters; returns the parameters unchanged.
--
-- Starts from 'initial_state' and, for every entry of 'obs', samples a
-- 'transition' and then scores the observation against the state the step
-- started from.  The final state is discarded.
model :: (MonadBayes m, CustomReal m ~ Double) => (Double, Double) -> m (Double, Double)
model params = do
  start <- initial_state params
  -- A monadic fold threads the state directly instead of folding over
  -- unevaluated actions, which would nest the binds to the left.
  foldM_ step start obs
  return params
  where
    step state p_obs = do
      new_state <- transition params state
      observe (Continuous (Normal (log (p state)) 0.1)) (log p_obs)
      return new_state

-- | Complete model: draws parameters from the prior, runs 'model' under a
-- particle filter with one resampling point per observation and
-- 'n_particles' particles, and collapses the population to the parameters.
full_model :: (MonadBayes m, CustomReal m ~ Double) => m (Double, Double)
full_model = do
  params <- parameters
  collapse (smc (length obs) n_particles (model params))

-- | PMMH proposal kernel: perturbs each of the two parameters with
-- unit-variance normal noise, in the order @[mu, sigma]@.
--
-- monad-bayes does not currently have truncated normals, so proposals are
-- not constrained to the support of the prior.
-- Calling it with anything other than a two-element list is an error.
pmmh_kernel :: (MonadDist m, CustomReal m ~ Double) => [Double] -> m [Double]
pmmh_kernel current = case current of
  [mu, sigma] -> mapM (`normal` 1) [mu, sigma]
  xs -> error $ "pmmh_kernel called with: " ++ show (length xs) ++ " variables"

-- | One full PMMH transition: propose with 'pmmh_kernel', estimate the
-- pseudo-marginal density of both the current and the proposed parameters,
-- and accept or reject with the Metropolis-Hastings ratio.
--
-- Returns the proposal when accepted and the unchanged input otherwise.
-- The current and proposed parameters are traced to stderr for debugging.
pmmh_step :: (MonadDist m, CustomReal m ~ Double) => [Double] -> m [Double]
pmmh_step params = do
  params' <- trace ("Params: " ++ show params) $ pmmh_kernel params
  let kernel_density  = unsafeContJointDensity (pmmh_kernel params) params'
      kernel_density' = unsafeContJointDensity (pmmh_kernel params') params
  -- NOTE(review): the density of the current parameters is re-estimated on
  -- every step rather than carried over from the step that accepted them.
  pm_density  <- pseudoDensity full_model (map Just params , [])
  pm_density' <- pseudoDensity full_model (map Just params', [])
  let mh_ratio = fromLogDomain (pm_density' * kernel_density' / (pm_density * kernel_density))
      -- A zero-over-zero ratio is NaN, and 'min 1 NaN' is NaN as well, so
      -- map it to a certain rejection instead of handing NaN to 'bernoulli'.
      accept_prob
        | isNaN mh_ratio = 0
        | otherwise      = min 1 mh_ratio
  accept <- bernoulli accept_prob
  trace ("Accept: " ++ show accept) $ return (if accept then params' else params)

-- | Applies a monadic step @n@ times, collecting the intermediate results.
--
-- For @n >= 1@ the vector holds @f x@, @f (f x)@, ... followed by one extra
-- copy of the last value, because the base case wraps the seed it receives;
-- the result therefore has @n + 1@ elements.  For @n <= 0@ the result is the
-- singleton @x@ (a negative count no longer recurses forever).
iterateNM :: Monad m => Int -> (a -> m a) -> a -> m (Vector a)
iterateNM n f x
  | n <= 0 = return $ V.singleton x
  | otherwise = do
      y <- f x
      V.cons y <$> iterateNM (n - 1) f y

-- | Binding to the C function @feenableexcept@; it is called with the mask
-- 'allFloatExceptions' at the start of 'main'.
-- NOTE(review): the C prototype presumably uses @int@, for which
-- @Foreign.C.Types.CInt@ would be the exact match -- confirm before relying
-- on the returned value.
foreign import ccall "feenableexcept" enableFloatException :: Int -> IO Int

-- | Mask passed to 'enableFloatException'.  Despite the name, only the
-- DIVBYZERO entry is currently switched on; the others are commented out.
allFloatExceptions :: Int
allFloatExceptions = sum enabled
  where
    enabled =
      [ -- 1  {-INVALID-}
        4     {-DIVBYZERO-}
        -- 8  {-OVERFLOW-}
        -- 16 {-UNDERFLOW-}
      ]

-- | Enables the floating-point exceptions selected by 'allFloatExceptions',
-- then runs two PMMH steps with a fixed random seed and prints one number
-- per parameter.
main :: IO ()
main = do
  _ <- enableFloatException allFloatExceptions
  ps1 <- sampleIOfixed $ do
    start <- parameters
    iterateNM 2 pmmh_step [fst start, snd start]
  let kept     = V.drop 1 ps1           -- skip the first recorded sample
      column i = V.map (!! i) kept      -- i-th parameter of every kept sample
  -- The divisor is hard-coded to 1, so these print plain sums.
  print (sum (column 0) / 1)
  print (sum (column 1) / 1)

obs :: [Double]
obs = take 3 $ [
  76.2403679754159,
  65.0098784873532,
  72.7307834964011,
  66.9032651694089,
  69.4511465390588,
  64.2803362618095,
  56.2848299444718,
  55.38783776159,
  58.4038970077208,
  65.1111562800144,
  61.8487358181798,
  54.6268190304816,
  61.7832255452353,
  57.2580992086863,
  51.6506692415061,
  53.513024643873,
  53.2516281763045,
  52.8808601383571,
  50.8172783401389,
  56.4015933848598,
  51.9694880143296,
  53.6405052882993,
  54.575606130029,
  59.3492395433371,
  60.0888983492202,
  44.4110818026195,
  54.4982828776438,
  73.4084712052465,
  54.7644338856397,
  65.9215790688941,
  64.4989883914755,
  80.9772033868803,
  74.4092779374009,
  69.1141635364459,
  67.4281988428452,
  73.2048837536519,
  68.120224957282,
  62.0069646761111,
  80.1892530691497,
  61.1742005270423,
  79.8803349040011,
  83.2200370887438,
  77.5918040799712,
  72.5731170839407,
  77.4547803718811,
  76.6217772509249,
  73.676388095968,
  85.6784703161388,
  104.170224855567,
  89.2220461329041,
  97.8557959956232,
  108.537541058624,
  80.7233691710187,
  101.184198386732,
  106.119103176147,
  100.475879494068,
  125.13032727884,
  91.0701758431548,
  88.8812050725806,
  100.679238971492,
  110.589530407881,
  89.6049957679831,
  103.926588273376,
  100.068112237358,
  115.512346913051,
  113.109941769312,
  114.117083726776,
  126.223426393656,
  116.655177064036,
  119.47938528113,
  119.673062865914,
  122.17342562363,
  115.260394834536,
  118.68195366763,
  124.943355707336,
  151.936975691988,
  114.130007656096,
  148.946645952698,
  144.565611912741,
  118.759397178281,
  126.673416199079,
  142.600114603459,
  124.392431912102,
  123.845615416597,
  131.162727200371,
  127.530085947053,
  149.07114126,
  123.535118463953,
  166.415187417087,
  131.719924167561,
  139.991005535832,
  130.070166508874,
  141.353801701055,
  136.634266000686,
  120.874938386323,
  135.843010192177,
  158.44951183627,
  150.728722863145,
  176.881378505555,
  133.774741754552,
  137.445456445669,
  124.34236836753,
  152.292445573619,
  145.969440861177,
  131.636047513673,
  146.963132155529,
  151.695165885035,
  165.977302905119,
  142.031450539543,
  164.846194851848,
  168.013420195697,
  152.767971695764,
  143.851647250068,
  160.436734291349,
  152.302294176293,
  155.067612585515,
  175.247982500084,
  136.897735270247,
  151.97211044291,
  160.887803973557,
  136.486833382118,
  158.883061549789,
  152.458944513575,
  149.023829176969,
  155.340944696967,
  157.829634728021,
  158.633421934879,
  132.036887061947,
  135.101844477751,
  153.29922033573,
  158.052850728581,
  143.369824868863,
  157.714843025493,
  144.325777788735,
  131.0549741648,
  144.337275322618,
  145.026723249829,
  160.030142868999,
  159.542825278056,
  147.87991947625,
  126.381489671728,
  145.226447360054,
  136.894318063363,
  159.268060107793,
  119.239998434002,
  132.788094593534,
  150.663078585921,
  144.430826746593,
  165.173155950077,
  147.929951578988,
  157.760394555343,
  131.053973697495,
  165.580489628044,
  146.718090811465,
  142.773179403759,
  146.392167581097,
  146.025983883888,
  141.47610191321,
  134.214216068753,
  119.248789990169,
  135.031489547166,
  120.561743419339,
  167.557212051444,
  148.168532426848,
  136.547249833344,
  152.416629324237,
  173.5981255812,
  132.895016695342,
  138.669418114316,
  167.112761599289,
  119.007943156896,
  144.496358594722,
  138.515189246643,
  158.554459930596,
  136.975247626486,
  129.513055665373,
  169.020473191163,
  156.116991826441,
  129.524981775373,
  148.85363969582,
  157.801922552141,
  157.577368682676,
  140.430975061626,
  136.205715759791,
  158.040217033728,
  125.493638253065,
  130.114292183725,
  142.750827896278,
  174.071675002185,
  155.168991141177,
  165.5803346601,
  133.967097475562,
  133.930047497004,
  127.601345216493,
  135.602745924639,
  163.280148579137,
  150.634881342163,
  144.898580660782,
  123.221738509521,
  170.394076047434,
  152.596479493751,
  156.695148025635,
  125.317613954204,
  152.957730717693,
  168.012683439549,
  120.716643641628,
  150.531587838906,
  145.104736014324,
  124.182149449396,
  133.269771329542,
  114.976356844067,
  139.860395302311,
  137.801555771354,
  139.678968034677,
  134.434880549483,
  138.011761267802,
  125.439482936179,
  151.74813237231,
  131.847695399538,
  123.319231908423,
  152.547130543284,
  125.491719298048,
  137.096981763213,
  138.768540718737,
  181.5283483391,
  126.098802734704,
  163.482620601302,
  128.342617368436,
  141.586517467223,
  119.474015267941,
  120.9364132138,
  112.984981940618,
  141.341437153727,
  135.160684269048,
  134.509200686011,
  163.319201985027,
  135.415854723738,
  159.063490681458,
  153.148663733748,
  149.026302196215,
  164.826019093328,
  123.071508271513,
  176.387627783159,
  149.483954577518,
  123.02095153944,
  156.650270662224,
  181.195533023262,
  141.238386891745,
  134.701979169655,
  131.340667101923,
  144.548953780858,
  131.588662675565,
  144.541170469088,
  125.578498701588,
  164.74439785325,
  127.766537075735,
  133.273837701266,
  145.93004226061,
  123.374900614912,
  151.596803755421,
  144.894219796883,
  126.998541138479,
  147.035896225008,
  136.476333130728,
  135.608414577445,
  131.807584253206,
  127.366264864742,
  134.964091100405,
  143.421168142046,
  142.054495699452,
  151.438892001945,
  143.05917211411,
  136.712518442789,
  120.003089311654,
  140.869561010692,
  135.058679734824,
  134.081730533486,
  139.43057733862,
  162.26933796043,
  136.471725650913,
  125.873826665898,
  151.097217418369,
  121.162250703819,
  150.767682408018,
  130.606266701801,
  105.20437547654,
  132.359328972528,
  116.989153980971,
  148.954925443656,
  142.105725642533,
  143.553119435633,
  127.751657026012,
  147.148424004414,
  114.85477950561,
  139.694548684757,
  146.171259764561,
  127.789841217549,
  163.168916396005,
  147.231424950466,
  125.058830151811,
  119.074143576494,
  133.536702795857,
  137.262453768502,
  179.672247639064,
  118.207552485961,
  136.618926147412,
  142.352972731303,
  136.479602357591,
  118.750751452668,
  147.986276828413,
  147.411481380629,
  144.052405584174,
  146.153708105283,
  131.557481669184,
  134.453090884824,
  153.443307349816,
  146.984446479185,
  138.421394583867,
  132.707447826652,
  136.516554357634,
  122.191521371155,
  134.80783216139,
  155.447901778778,
  136.926153057977,
  125.184160305696,
  122.367097582219,
  134.726963096563,
  138.72230909606,
  133.531816345683,
  138.817134160604,
  151.680755115041,
  143.081423812855,
  144.507716078662,
  150.54905673819,
  164.395431529991,
  136.035307291772,
  135.963680724784,
  159.088071380666,
  115.544651209282,
  126.449866575417,
  145.861556307966,
  121.299809587301,
  135.041076468273,
  115.420027815205,
  120.649592931961,
  149.222665704902,
  136.787170312266,
  117.577955349667,
  123.09131466115,
  149.613666057669,
  127.542230989168,
  148.826516120672,
  172.024769211347,
  119.163511852413,
  168.846806388603,
  122.577905606408,
  163.691820942384,
  121.152394606119,
  125.364015734805,
  134.053429227603,
  116.102707413675,
  115.932034184824,
  157.856002757677,
  120.417539311501,
  159.448081983552,
  146.523081800403,
  142.244157886188,
  147.49875516021,
  149.794865803213,
  142.424753997512,
  155.377391656754,
  133.457569330548,
  157.979204500375,
  132.867976837153,
  150.766104208022,
  151.574164752913,
  142.79494077867,
  135.300935952747,
  155.087644429195,
  134.339799297657,
  117.780905975319,
  141.325010750032,
  137.141299531074,
  150.973508888167,
  135.037762357837,
  145.031582826828,
  138.551290507779,
  144.969403511833,
  162.819743874685,
  132.441267184392,
  131.380740568251,
  141.527642238539,
  124.238855213983,
  155.588321253241,
  145.068004565977,
  120.849885116319,
  138.600747823775,
  138.767637532808,
  149.658261434865,
  153.218173715505,
  123.173078726806,
  144.980628125895,
  133.601607881188,
  143.600386923789,
  134.99375013592,
  178.108618865062,
  162.589164120014,
  149.801157966121,
  124.574389244282,
  159.749711094236,
  154.462697945675,
  124.745887246537,
  134.619145288027,
  132.829029331846,
  133.515016847945,
  130.370138175619,
  134.058600191345,
  138.967206000683,
  144.844845114948,
  140.150588785047,
  168.155026956341,
  105.663614240897,
  188.849879696026,
  134.066594679321,
  118.875307519085,
  152.36326566797,
  126.208445705652,
  168.297022523569,
  152.27258877451,
  140.382228072818,
  119.424396673087,
  161.465352774379,
  151.655127807313,
  135.515116635163,
  151.532357393466,
  151.810511351233,
  136.361410128331,
  153.741012514592,
  147.431121676277,
  157.91915964734,
  120.327187813985,
  120.171500909668,
  142.428896431138,
  170.204819149259,
  158.214506571837,
  138.733887337899,
  123.808097192822,
  152.903612727097,
  140.769759838882,
  119.667303148295,
  129.45932949513,
  149.16660912873,
  150.254944400129,
  128.273570440445,
  153.278564595965,
  121.864809541511,
  161.542701435884,
  156.946807512148,
  158.881028441797,
  125.569197718704,
  137.940620707105,
  129.90603859364,
  152.757290961669,
  127.74685675702,
  137.183947553826,
  150.462794723172,
  140.929843406962,
  164.975435125998,
  182.344447097492,
  156.919848441299,
  150.750704137777,
  141.693725748231,
  110.676915031473,
  173.230786634625,
  148.248841760484,
  119.716751121111,
  143.947798672895,
  130.440507876832,
  131.964077776465,
  134.646898631033,
  100.778877180885,
  124.348744604234,
  127.180330688462,
  158.146060814462,
  171.358807156152,
  131.864294544909,
  143.085939246005,
  149.586050733912,
  132.2767374166,
  129.632802049237,
  137.02205312678,
  146.243848392623,
  125.713745961482,
  132.545872818923,
  129.355428299295,
  150.515004650034,
  149.481788573074,
  154.478850748833,
  137.329325663629,
  127.9504469047,
  133.350229974433]
idontgetoutmuch commented 7 years ago

I have narrowed down the error a bit

-- full simulation given parameters, returns the parameters
--
-- Left-folds `step` over `obs`, starting from `initial_state params`. The
-- state produced by the fold is discarded (`>>`) and `params` is returned.
model :: (MonadBayes m, CustomReal m ~ Double) => (Double, Double) -> m (Double, Double)
model params = foldl step (initial_state params) obs >> return params where
  -- One fold step: `old_state` is the computation accumulated so far and
  -- `p_obs` is the next element of `obs`.
  step old_state p_obs = do
    state <- old_state
    new_state <- transition params state
    -- Temporary debugging probe: aborts with the current state and observation.
    -- It is placed ahead of `observe` to locate the failure relative to that call.
    error $ "You are here " ++ show state ++ " " ++ show p_obs
    -- Conditions on log p_obs under a Normal whose first parameter is
    -- log (p state). Note this uses the pre-transition `state`, not `new_state`.
    -- NOTE(review): 0.1 is presumably the standard deviation — confirm.
    observe (Continuous (Normal (log (p state)) 0.1)) (log p_obs)
    return new_state

gives

Main: You are here S {p = 96.41248458072532, z = 59.16931688468654, log_alpha = -3.3251317115519465} 76.2403679754159
CallStack (from HasCallStack):
  error, called at app/Main.hs:105:5 in main:Main

Moving the error below observe gives

Floating point exception: 8

So now we know something is going wrong inside observe

idontgetoutmuch commented 7 years ago

I have had a go at using dwarf and at using ghci -fexternal-interpreter -prof. The former throws lots of errors and I haven't investigated; the latter would mean re-installing the world with profiling turned on. Possibly the cheapest way to debug this is to add Show a everywhere and put in more trace statements. What do you think?

adscib commented 7 years ago

Looks like this is a problem with the model rather than inference. I made the following change:

-- Variant of `model` instrumented to print the observation mean and the
-- observed value at every step. Left-folds `step` over `obs` from
-- `initial_state params`, discards the final state and returns `params`.
model :: (MonadBayes m, CustomReal m ~ Double) => (Double, Double) -> m (Double, Double)
model params = foldl step (initial_state params) obs >> return params where
  -- One fold step: `old_state` is the computation accumulated so far and
  -- `p_obs` is the next element of `obs`.
  step old_state p_obs = do
    state <- old_state
    new_state <- transition params state
    -- Mean of the observation model, taken from the pre-transition `state`.
    let m = log (p state)
    -- Observed value, on the log scale.
    let x = log p_obs
    -- The trace is attached to the `observe` expression, so the message is
    -- emitted when that expression is evaluated.
    trace ("mean: " ++ show m ++ " x: " ++ show x) $ observe (Continuous (Normal m 0.1)) x
    return new_state

and mean is showing up as NaN, because p state is negative.

idontgetoutmuch commented 7 years ago

Thanks for this - I will investigate further. Intriguingly the model works in Stan and LibBI.

adscib commented 7 years ago

Most likely I made a mistake translating the model. The first thing to check would be the parametrisation of lognormal.

idontgetoutmuch commented 7 years ago

I will take a look tomorrow - thanks.

adscib commented 7 years ago

Turns out what I took for the solver running very slowly was a silly bug that caused an infinite loop when NaNs were added in LogDomain. This is fixed on master now. I can currently run the model without problems (after disabling floating point exceptions), but the mean of the observation model still comes up as NaN so that part still needs to be fixed to get any useful results.

idontgetoutmuch commented 7 years ago

Great! This makes debugging the model problem much easier (I hope).

idontgetoutmuch commented 7 years ago

I don't think it can be the model. I now have

-- full simulation given parameters, returns the parameters
--
-- Variant instrumented to print the state before and after each transition.
-- Left-folds `step` over `obs` from `initial_state params`, discards the
-- final state and returns `params`.
model :: (MonadBayes m, CustomReal m ~ Double) => (Double, Double) -> m (Double, Double)
model params = foldl step (initial_state params) obs >> return params where
  -- One fold step: `old_state` is the computation accumulated so far and
  -- `p_obs` is the next element of `obs`.
  step old_state p_obs = do
    state <- old_state
    new_state <- transition params state
    -- Debug output: state fed into `transition`, then the state it produced.
    trace ("Old State: " ++ show state) $ return ()
    trace ("New state: " ++ show new_state) $ return ()
    -- Mean of the observation model, taken from the pre-transition `state`.
    let m = log (p state)
    -- Observed value, on the log scale.
    let x = log p_obs
    -- NOTE(review): 0.1 is presumably the standard deviation — confirm.
    observe (Continuous (Normal m 0.1)) x
    return new_state

and get

...
Old State: S {p = 68.03763906369906, z = 35.283787110815105, log_alpha = -0.9916988019989725}
New state: S {p = 64.90164005295357, z = 30.422294131721443, log_alpha = -0.12244882682932601}
Old State: S {p = 103.98419668715715, z = 51.46501997223569, log_alpha = NaN}
New state: S {p = NaN, z = 41.96850310814045, log_alpha = -2.6446936218209336}
...

Notice it is the model that is being passed the NaN; it does not create it. So maybe smc is the culprit somehow?

I am happy to debug further (well happy is putting it a bit strongly) but I am at a bit of a loss on how to proceed. Would it be possible to add Show instances somehow in monad-bayes? At the moment, because of the type family, CustomReal, it is not always possible to print values.

I will prod the ghc mailing list again to see if we can catch the floating point exception.

adscib commented 7 years ago

There's no way SMC is inserting the NaNs here, since it is fully polymorphic in the program's output type. Any value of state is ultimately sampled from the prior. However, SMC is interleaving execution of multiple instances of the program, so the old state you print does not match the new state. But I bet the very first NaN you see appears in the new state rather than old state.

For debugging I recommend you try to run sampleIO $ runWeighted $ parameters >>= model instead of SMC. This is doing simple importance sampling from the prior so if you see any NaNs there it is certainly a problem with the model. I suggest you run this multiple times printing the values of all intermediate variables in the model - this should let you pinpoint the problem. I will be very surprised if no NaNs show up. You will need to import Control.Monad.Bayes.Weighted.

Finally, if you need to print CustomReal you can take advantage of the Real instance by using show . toRational. But in the model you can constrain CustomReal m ~ Double so it shouldn't be a problem.

idontgetoutmuch commented 7 years ago

Ok I have implemented your first suggestion. I don't get NaNs in the state but I do in the likelihood(?).

-- Debugging driver: runs `runWeighted (parameters >>= model)` 1000 times
-- under `sampleIO` (no SMC involved) and prints the collected results.
main :: IO ()
main = do
  -- NOTE(review): each result appears to be a (parameters, weight) pair — confirm
  -- against `runWeighted`.
  d <- sampleIO $ replicateM 1000 (runWeighted $ parameters >>= model)
  putStrLn $ show d
...,
((0.961757962405187,3.8030585526716627),NaN),
...

I will investigate further.

adscib commented 7 years ago

You don't get NaNs in the output, since the model outputs parameters. But you should be getting them in the intermediate states, in particular in log (p state) used as a mean of the observation model.

idontgetoutmuch commented 7 years ago

We had

sigma <- uniform 0 5

It should have been

sigma <- uniform 0 0.5

The wider range gave the growth rate very large values, which initially made the number of prey very large; the model then over-compensated, making the number of prey negative and leading to NaNs. Chains are now running 👍