ckemere closed this issue 8 years ago
Hi! You need to initialize some of the parameters randomly in order to break the symmetry. Otherwise all mixture components are identical. I think this is the problem you're facing. So initialize the mean randomly:
mu_est.initialize_from_value(np.random.randn(K2, D))
Also, I'd give the variance parameter a better initialization:
Lambda_est.initialize_from_value(np.identity(D))
In order to update `Z` before `mu_est`, use:
Q = VB(Y, Z, mu_est, Lambda_est, A, a0)
Also, I noticed that you don't have mixture plates for `Lambda_est`. Maybe that's what you want, but if you want a different variance for each mixture cluster, use:
Lambda_est = Wishart(D, 1e-5*np.identity(D), plates=(K2,))
Here is a complete working example:
K2 = 10
D = 2
# Just some dummy data now
N = 100
y = np.concatenate([np.random.randn(50,2), 2+np.random.randn(50,2)], axis=0)
from bayespy.nodes import Dirichlet
a0 = Dirichlet(1e-3*np.ones(K2))
A = Dirichlet(1e-3*np.ones((K2,K2)))
from bayespy.nodes import CategoricalMarkovChain
Z = CategoricalMarkovChain(a0, A, states=N)
from bayespy.nodes import Gaussian, Wishart
mu_est = Gaussian(np.zeros(D), 1e-5*np.identity(D), plates=(K2,))
Lambda_est = Wishart(D, 1e-5*np.identity(D), plates=(K2,))  # <- different variance for each cluster?
from bayespy.nodes import Mixture
Y = Mixture(Z, Gaussian, mu_est, Lambda_est)
Y.observe(y)
# Random initialization to break the symmetry
mu_est.initialize_from_value(np.random.randn(K2, D))
# Reasonable initialization for Lambda
Lambda_est.initialize_from_value(np.identity(D))
from bayespy.inference import VB
Q = VB(Y, Z, mu_est, Lambda_est, A, a0)
Q.update(repeat=1000)
from bayespy import plot as bpplt
bpplt.hinton(Z)
If you want to make the learning less sensitive to the initialization, you can try using deterministic annealing: http://www.bayespy.org/user_guide/advanced.html#deterministic-annealing
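A typical annealing run starts with a small inverse temperature and ramps it up to one, updating the model at each step. The sketch below shows only the schedule; the BayesPy calls are left as comments so it runs standalone (it assumes the `Q` object from the example above, and the `set_annealing()`/`update()` methods described in the linked user-guide page):

```python
# Deterministic annealing schedule sketch: start with a flattened
# posterior (small beta) and gradually sharpen it to the true one.
beta = 0.1
schedule = [beta]
while beta < 1.0:
    beta = min(1.5 * beta, 1.0)
    schedule.append(beta)
    # Q.set_annealing(beta)          # assumed BayesPy call, per user guide
    # Q.update(repeat=100, tol=1e-5)
```

The exact starting value and growth factor are tuning choices; the key point is that `beta` reaches 1.0, at which point the updates are the standard VB updates.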
I hope this helps (and works). Please don't hesitate to ask further questions on this.
Beautiful. Thanks much! I was wondering how the initialization step was being taken care of!
Edit: For a classical Gaussian-emission HMM, there should obviously be plates for Lambda. I was thinking this would imply covariance between the mixture components, but that was just my brain being confused. Thanks again for that.
It might be worth adding this to your tutorial? The data-generation code is here:
import numpy as np

# simulated data
mu = np.array([[0, 0], [3, 4], [6, 0]])
D = 2
std = 2.0
K = 3    # number of "states"
N = 200  # number of samples
p0 = np.ones(K) / K
q = 0.9  # self-transition probability
r = (1 - q) / (K - 1)
P = q*np.identity(K) + r*(np.ones((K, K)) - np.identity(K))  # transition probability matrix

# run simulation
y = np.zeros((N, D))
z = np.zeros(N, dtype=int)
state = np.random.choice(K, p=p0)
for n in range(N):
    z[n] = state
    y[n, :] = std * np.random.randn(D) + mu[state]
    state = np.random.choice(K, p=P[state])
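As a quick NumPy-only sanity check, a transition matrix built this way should be row-stochastic (each row sums to one), with the self-transition probability `q` on the diagonal:

```python
import numpy as np

K = 3
q = 0.9                    # self-transition probability
r = (1 - q) / (K - 1)      # probability of each off-diagonal transition
P = q * np.identity(K) + r * (np.ones((K, K)) - np.identity(K))

# Every row of a valid transition matrix is a probability vector.
assert np.allclose(P.sum(axis=1), 1.0)
assert np.allclose(np.diag(P), q)
```

Writing the construction in terms of `K` everywhere (rather than hard-coding the state count) keeps the check valid if the number of states changes.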
After I fix this issue: https://github.com/bayespy/bayespy/issues/30 you could initialize Z randomly with `Z.initialize_from_random()`, and then update the mean and variance before Z. That would be better, in my opinion. Also, some other steps can be used to improve the accuracy of the VB approximation and to reduce the sensitivity to the initialization, but they will make things a little bit more complex. First, as I mentioned, you could use deterministic annealing. Second, you could use `GaussianWishart` or `GaussianGammaARD` nodes to model the mean and variance in a single node. If you want diagonal variance, you can use `Gamma` and `GaussianARD`. I can give more details another time.
The annealing seems to make things a bit more stable, but interestingly, training with two sequences makes the estimation much more stable. Presumably this is because a0, the initial state distribution, is arbitrary in the one-sequence case but has evidence in the two-sequence case (in the one-sequence case the symmetries are left unbroken).
@jluttine how does one convert the transition probabilities matrix of the HMM (`A` in your example above) into a valid transition kernel? I'm not actually sure what the matrix contains. Dirichlet concentration parameters?
@RylanSchaeffer `A` is a matrix of transition probabilities. Each row corresponds to a probability vector that sums to one. These vectors are given a Dirichlet prior, which is a distribution over probability vectors. The node `A` corresponds to the unknown matrix of transition probabilities, so it's first given a Dirichlet prior, and after fitting it'll contain the approximate posterior distribution of the transition probabilities (which is also a Dirichlet distribution).
Does `bp.nodes.Categorical(A).get_moments()[0]` give you what you want? It gives the posterior probability of state transitions.
Yes, `bp.nodes.Categorical(A).get_moments()[0]` is exactly what I was looking for. I thought I should be able to extract the transition probabilities from the posterior `A` directly.
@RylanSchaeffer Yeah, you can get it from `A` too. You can extract the parameters of the Dirichlet distribution and do whatever you want with those. Or you can get the posterior moments of the variable, but those contain only `<log(A)>`. So, this was just a somewhat "hackish" way of getting a normalized version of the posterior mean.
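For reference, the mean of a Dirichlet distribution is its normalized concentration vector, so if you extract the posterior concentration parameters of each row of `A`, normalizing them row-wise gives the posterior-mean transition matrix. A NumPy-only sketch with made-up values (the `alpha` matrix below is hypothetical, not read from a fitted model):

```python
import numpy as np

# Hypothetical posterior Dirichlet concentration parameters, one row per
# state; in a fitted model these would come from the posterior of A.
alpha = np.array([[90.0,  5.0,  5.0],
                  [ 5.0, 90.0,  5.0],
                  [ 5.0,  5.0, 90.0]])

# The mean of Dirichlet(alpha_i) is alpha_i / sum(alpha_i), so normalizing
# each row yields a valid row-stochastic transition matrix.
transition_probs = alpha / alpha.sum(axis=-1, keepdims=True)
```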
@jluttine how does one extract the posterior over latent states from `CategoricalMarkovChain`? The object has an array with shape `(length of chain, number of latent variables, number of latent variables)`. I don't see a clear answer in the documentation (https://www.bayespy.org/user_api/generated/generated/bayespy.nodes.CategoricalMarkovChain.html) or the example (https://www.bayespy.org/examples/hmm.html).
Edit: This seems so simple. I must be missing something obvious.
In general, it feels like most of the tutorials stop after inferring the posterior for the parameters but don't explain how to retrieve it. This is true of the multinomial tutorial (https://www.bayespy.org/examples/multinomial.html) and the GMM tutorial (https://www.bayespy.org/examples/gmm.html), I think.
@jluttine I just realized you're the author of tikz-bayesnet!! Wow!!!
@RylanSchaeffer The posterior distribution is represented by sufficient statistics (moments) and natural parameters. You can get the moments with the `get_moments()` method on any node. For a random variable (i.e., not a deterministic node), you can get the natural parameters via the `phi` attribute, if I remember correctly. So, `phi` contains the parameters of the posterior distribution. These two representations have the same shape. For `CategoricalMarkovChain`, both are lists of two elements. The first element corresponds to the initial state and the second element is a large array containing all transitions. In order to interpret and use these correctly, you need to understand the exact probability distribution definition the node is using. Unfortunately, it seems that the documentation doesn't necessarily cover these formulations of the distributions...
But if you're happy with, for instance, the posterior expected values of the hidden states, those you can get easily. The first element in the moments list is `<z_0>`. The second element is an array containing all expectations of the form `<z_n z_{n+1}^T>`. So, for that array, you just need to marginalize (i.e., sum) over the second-to-last axis, and you'll get the expectations of `z_n` for `n = 1, ...`. (I'm not 100% sure; I'm writing from memory and not checking anything at the moment.)
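That marginalization step can be sketched with plain NumPy. The `pairwise` array below is synthetic, generated so that each slice is a valid joint distribution; in BayesPy it would be the second element returned by `Z.get_moments()`:

```python
import numpy as np

N, K = 6, 3
rng = np.random.default_rng(0)

# Synthetic pairwise expectations <z_n z_{n+1}^T>: each (K, K) slice is a
# joint distribution over consecutive states, so its entries sum to one.
pairwise = rng.dirichlet(np.ones(K * K), size=N - 1).reshape(N - 1, K, K)

# Summing over the second-to-last axis marginalizes out z_n, leaving the
# per-step expectations <z_{n+1}> for n = 0, ..., N-2.
z_expect = pairwise.sum(axis=-2)
```

Each row of `z_expect` is then a probability vector over the `K` states at that time step; summing over the last axis instead would give the marginals for the earlier state of each pair.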
And yes, the documentation definitely could add more information. Pull requests for such improvements are of course most welcome. But what would you want to extract exactly in this case? That is, what kind of representation of the posterior are you expecting or hoping to get? Some non-natural but widely used / standard parameterizations will probably require some manual conversions. Those could be of course implemented node-by-node as needed, but I haven't done that comprehensively.
Yeah, I'm an author of tikz-bayesnet together with Laura Dietz. Glad if the package has been useful for you!
I wanted to modify the example HMM to additionally estimate the mean and variance of the observations. Oddly, I find that there's rapid convergence to an uninformative model. Any thoughts?
Here's my code (which follows the data generation from the example):