probcomp / Gen.jl

A general-purpose probabilistic programming system with programmable inference
https://gen.dev
Apache License 2.0

Fixing the value range that learnable parameters can take when being optimised #499

Closed min-nguyen closed 1 year ago

min-nguyen commented 1 year ago

Hi there!

Using black_box_vi!, I'm trying to optimise the parameters of some Beta distributions in a Hidden Markov Model:

using Gen

@gen function hmmGuide(T::Int)
  @param trans_p_a
  @param trans_p_b
  @param obs_p_a
  @param obs_p_b
  trans_p = @trace(beta(trans_p_a, trans_p_b), :trans_p)
  obs_p   = @trace(beta(obs_p_a, obs_p_b), :obs_p)
end

@gen function hmm(T::Int)
  trans_p = @trace(beta(2, 2), :trans_p)
  obs_p   = @trace(beta(2, 2), :obs_p)
  x  = 0::Int
  ys = Array{Int}(undef, T)
  for t=1:T
    dX    = @trace(bernoulli(trans_p), (:x, t))
    x     = x + Int(dX)
    ys[t] = @trace(binom(x, obs_p), (:y, t))
  end
  return ys
end

function bbviHMM(n_iters::Int, n_samples::Int, T::Int)
  # Create a set of constraints fixing the
  # y coordinates to the observed y values
  ys = hmm(T)
  constraints = choicemap()
  for (i, y) in enumerate(ys)
    constraints[(:y, i)] = y
  end

  init_param!(hmmGuide, :trans_p_a, 0.5)
  init_param!(hmmGuide, :trans_p_b, 0.5)
  init_param!(hmmGuide, :obs_p_a, 0.5)
  init_param!(hmmGuide, :obs_p_b, 0.5)

  update = ParamUpdate(GradientDescent(1e-15, 1000000000), hmmGuide)
  black_box_vi!(hmm, (T,), constraints, hmmGuide, (T,), update;
    iters=n_iters, samples_per_iter=n_samples, verbose=true)
end

bbviHMM(500, 500, 10)

But I'm running into NaN errors; I suspect my parameters are leaving their valid range during optimisation (the Beta shape parameters must be positive):

ERROR: LoadError: DomainError with NaN:
Beta: the condition α > zero(α) is not satisfied.

I've tried setting the step size to be very small. Is there a way to prevent this from happening, such as enforcing an inequality constraint on the parameters?

Thanks loads!

ztangent commented 1 year ago

Hi @min-nguyen! I don't think there's an easy way to enforce constraints directly -- that would probably require a constrained optimization / gradient descent algorithm that we don't currently have implemented.

Have you considered parameterizing your guide distribution in terms of the log of the parameter you care about instead? For example:

@gen function hmmGuide(T::Int)
  @param log_trans_p_a
  @param log_trans_p_b
  @param log_obs_p_a
  @param log_obs_p_b
  trans_p = @trace(beta(exp(log_trans_p_a), exp(log_trans_p_b)), :trans_p)
  obs_p   = @trace(beta(exp(log_obs_p_a), exp(log_obs_p_b)), :obs_p)
end

I'm pretty sure this should avoid the issues you're running into, though I can't be certain without trying myself!
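
To illustrate why the log-parameterisation keeps the shape parameters valid, here is a minimal plain-Julia sketch, independent of Gen. The gradient value, the step size, and the helper names step_raw and step_log are hypothetical, chosen only so that a raw update overshoots zero; this is not a description of how black_box_vi! works internally.

```julia
# One gradient step taken directly on a shape parameter.
# Nothing prevents the result from crossing zero.
step_raw(alpha, grad_alpha, lr) = alpha + lr * grad_alpha

# The same step taken on the unconstrained log-parameter.
# The chain rule converts the gradient: dL/d(log alpha) = alpha * dL/d(alpha).
function step_log(log_alpha, grad_alpha, lr)
    grad_log = exp(log_alpha) * grad_alpha
    return log_alpha + lr * grad_log
end

alpha0 = 0.5
grad_alpha = -10.0   # hypothetical gradient, large enough to overshoot
lr = 0.1             # hypothetical step size

alpha_raw = step_raw(alpha0, grad_alpha, lr)                  # -0.5, invalid for Beta
alpha_log = exp(step_log(log(alpha0), grad_alpha, lr))        # about 0.303, still positive

println("raw update:       ", alpha_raw)
println("log-space update: ", alpha_log)
```

With the renamed parameters, the init_param! calls in bbviHMM would presumably need to target :log_trans_p_a and so on instead, for example with log(0.5) to keep the same starting shape of 0.5.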

min-nguyen commented 1 year ago

Ah that makes perfect sense, thanks! Unfortunately, I still run into the same error; I'm very confused.

I've tried using the same exp pattern in the model itself, as well as changing the init_param! values to larger floats.

I think this originates from the objective becoming negative infinity; the output reports: est objective: -Inf

min-nguyen commented 1 year ago

I think the model itself is a bit problematic, e.g. an observation of 5 may be provided to a distribution like binomial(4, 0.5), where the log-density would be -Infinity. I've managed to "resolve" this by using a normal distribution rather than a Binomial one. I wonder if there's a way to not have to do this; anyway, I really appreciate the help, thanks again.
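
The binomial(4, 0.5) case above can be checked with a small plain-Julia sketch. The function name binom_logpmf is hypothetical and written out with Base only; it is not Gen's own implementation of binom, and it assumes 0 < p < 1.

```julia
# Log-probability mass of Binomial(n, p) at k, assuming 0 < p < 1.
# Outside the support 0:n the mass is zero, so the log-density is -Inf.
function binom_logpmf(k::Int, n::Int, p::Float64)
    (0 <= k <= n) || return -Inf
    return log(binomial(n, k)) + k * log(p) + (n - k) * log1p(-p)
end

inside  = binom_logpmf(2, 4, 0.5)   # finite: 2 is a possible outcome of 4 trials
outside = binom_logpmf(5, 4, 0.5)   # -Inf: 5 successes in 4 trials is impossible

println("log p(2 | n=4, p=0.5) = ", inside)
println("log p(5 | n=4, p=0.5) = ", outside)

# A single impossible observation makes the whole summed log-density -Inf.
println("sum = ", inside + outside)
```

Once a -Inf enters the objective, ordinary floating-point arithmetic can turn it into NaN (for instance 0.0 * -Inf is NaN), which is one plausible route from "est objective: -Inf" to the DomainError on the Beta parameters.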

ztangent commented 1 year ago

Ah yes, if your observations have support over all the non-negative integers, then binomial is not a great choice for the observation noise distribution! Doing inference assuming a binomial would require solving a non-trivial constraint satisfaction problem, e.g. given that the observation was 5 at timestep 6, what possible sequences of dXs could have been sampled such that the probability of observing 5 was non-zero?

Black-box VI can't automatically solve that problem for you -- though I believe traditional forward-backward HMM algorithms should be able to, because they enumerate over all possibilities!
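
As a rough illustration of what "enumerate over all possibilities" means for the model in this thread, here is a minimal forward-pass sketch in plain Julia. It assumes trans_p and obs_p are held fixed at known values; the names binom_pmf and forward_likelihood are hypothetical, and this is not the code of any Gen library.

```julia
# Binomial(n, p) mass at k; zero outside the support 0:n. Assumes 0 < p < 1.
binom_pmf(k, n, p) = (0 <= k <= n) ? binomial(n, k) * p^k * (1 - p)^(n - k) : 0.0

# Forward pass for the counter HMM:
#   x_0 = 0,  x_t = x_{t-1} + Bernoulli(trans_p),  y_t ~ Binomial(x_t, obs_p).
# alpha[i + 1] holds P(x_t = i, y_1..y_t); x_t can only take the values 0..t.
function forward_likelihood(ys::Vector{Int}, trans_p::Float64, obs_p::Float64)
    T = length(ys)
    alpha = zeros(T + 1)
    alpha[1] = 1.0                                   # x_0 = 0 with probability 1
    for t in 1:T
        next = zeros(T + 1)
        for i in 0:t
            stay = alpha[i + 1] * (1 - trans_p)      # x stayed at i
            move = i > 0 ? alpha[i] * trans_p : 0.0  # x moved from i - 1 to i
            next[i + 1] = (stay + move) * binom_pmf(ys[t], i, obs_p)
        end
        alpha = next
    end
    return sum(alpha)                                # P(y_1..y_T)
end

println(forward_likelihood([0, 1, 1], 0.5, 0.5))    # a feasible sequence
println(forward_likelihood([2], 0.5, 0.5))          # infeasible: x_1 is at most 1
```

Because every reachable value of x is summed over, an infeasible observation sequence simply receives likelihood zero, instead of derailing a sampled estimate.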

ztangent commented 1 year ago

Incidentally, there is some specialized code in this Gen extension library that I believe supports the forward-backward algorithm for HMMs --- I believe it assumes that the model has no continuous variables, but you may be able to use it within a larger model to estimate the parameters that you're interested in!

https://github.com/probcomp/GenVariableElimination.jl