User:Timothee Flutre/Notebook/Postdoc/2011/12/14
From OpenWetWare
m (→Learn about mixture models and the EM algorithm: minor refactoring) 
(→Learn about mixture models and the EM algorithm: clarify E step) 

(4 intermediate revisions not shown.)  
Line 46:  Line 46:  
* '''Missing data''': we introduce the following N [http://en.wikipedia.org/wiki/Latent_variable latent variables] <math>Z_1,...,Z_i,...,Z_N</math> (also called hidden or allocation variables), one for each observation, such that <math>Z_i=k</math> means that observation <math>x_i</math> belongs to cluster <math>k</math>. In fact, it is much easier to work the equations when defining each <math>Z_i</math> as a vector of length <math>K</math>, with <math>Z_{ik}=1</math> if observation <math>x_i</math> belongs to cluster <math>k</math>, and <math>Z_{ik}=0</math> otherwise ([http://en.wikipedia.org/wiki/Dummy_variable_%28statistics%29 indicator variables]). Thanks to this, we can reinterpret the mixture weights: <math>\forall i, P(Z_i=k\theta)=w_k</math>. Moreover, we can now define the membership probabilities, one for each observation:  * '''Missing data''': we introduce the following N [http://en.wikipedia.org/wiki/Latent_variable latent variables] <math>Z_1,...,Z_i,...,Z_N</math> (also called hidden or allocation variables), one for each observation, such that <math>Z_i=k</math> means that observation <math>x_i</math> belongs to cluster <math>k</math>. In fact, it is much easier to work the equations when defining each <math>Z_i</math> as a vector of length <math>K</math>, with <math>Z_{ik}=1</math> if observation <math>x_i</math> belongs to cluster <math>k</math>, and <math>Z_{ik}=0</math> otherwise ([http://en.wikipedia.org/wiki/Dummy_variable_%28statistics%29 indicator variables]). Thanks to this, we can reinterpret the mixture weights: <math>\forall i, P(Z_i=k\theta)=w_k</math>. Moreover, we can now define the membership probabilities, one for each observation:  
  <math>p(ki) = P(  +  <math>p(ki) = P(Z_{ik}=1x_i,\theta) = \frac{w_k \phi(x_i\mu_k,\sigma_k)}{\sum_{l=1}^K w_l \phi(x_i\mu_l,\sigma_l)}</math> 
The observeddata likelihood (also called sometimes "incomplete" or "marginal", even though these appellations are misnomers) is still written the same way:  The observeddata likelihood (also called sometimes "incomplete" or "marginal", even though these appellations are misnomers) is still written the same way:  
Line 54:  Line 54:  
But now we can also write the augmenteddata likelihood, assuming all observations are independent conditionally on their membership:  But now we can also write the augmenteddata likelihood, assuming all observations are independent conditionally on their membership:  
  <math>L_{aug}(\theta) = P(X,Z\theta) = \prod_{i=1}^N P(x_i  +  <math>L_{aug}(\theta) = P(X,Z\theta) = \prod_{i=1}^N P(x_iZ_i,\theta) P(Z_i\theta) = \prod_{i=1}^N \left( \prod_{k=1}^K \phi(x_i\mu_k,\sigma_k)^{Z_{ik}} w_k^{Z_{ik}} \right)</math>. 
+  
+  And here is the augmenteddata loglikelihood (useful in the M step of the EM algorithm, see below):  
+  
+  <math>l_{aug}(\theta) = \sum_{i=1}^N \left( \sum_{k=1}^K Z_{ik} ln(\phi(x_i\mu_k,\sigma_k)) + \sum_{k=1}^K Z_{ik} ln(w_k) \right)</math>  
+  
+  In terms of [http://en.wikipedia.org/wiki/Graphical_model graphical model], the Gaussian mixture model described here can be represented like [http://en.wikipedia.org/wiki/File:Nonbayesiangaussianmixture.svg this].  
Line 112:  Line 118:  
  * '''  +  * '''Formulas of both steps''': in both steps we need to use <math>Q</math>, whether to evaluate it or maximize it. 
+  
+  <math>Q(\thetaX,\theta^{(t)}) = \mathbb{E}_{ZX,\theta^{(t)}} \left[ ln(P(X,Z\theta))X,\theta^{(t)} \right]</math>  
+  
+  <math>Q(\thetaX,\theta^{(t)}) = \mathbb{E}_{ZX,\theta^{(t)}} \left[ l_{aug}(\theta)X,\theta^{(t)} \right]</math>  
+  
+  <math>Q(\thetaX,\theta^{(t)}) = \sum_{i=1}^N \left( \sum_{k=1}^K \mathbb{E}_{ZX,\theta^{(t)}}[Z_{ik}x_i,\theta_k^{(t)}] ln(\phi(x_i\mu_k,\sigma_k)) + \sum_{k=1}^K \mathbb{E}_{ZX,\theta^{(t)}}[Z_{ik}x_i,\theta_k^{(t)}] ln(w_k) \right)</math>  
+  
+  
+  * '''Formulas of the E step''': as indicated above, the E step consists in evaluating <math>Q</math>, i.e. simply evaluating the conditional expectation over the latent variables of the augmenteddata loglikelihood given the observed data and the current estimates of the parameters.  
+  
+  <math>\mathbb{E}_{ZX,\theta^{(t)}}[Z_{ik}x_i,\theta_k^{(t)}] = P(Z_{ik}=1x_i,\theta_k^{(t)}) = \frac{w_k^{(t)} \phi(x_i\mu_k^{(t)},\sigma_k^{(t)})}{\sum_{l=1}^K w_l^{(t)} \phi(x_i\mu_l^{(t)},\sigma_l^{(t)})} = p(ki)</math>  
+  
+  
+  * '''Formulas of the M step''': in this step, we need to maximize <math>Q</math> (also written <math>\mathcal{F}</math> above), w.r.t. each <math>\theta_k</math>. A few important rules are required to write down the analytical formulas of the MLEs, but only from a highschool level (see [http://en.wikipedia.org/wiki/Differentiation_%28mathematics%29#Rules_for_finding_the_derivative here]).  
+  
+  
+  * '''M step  weights''': let's start by finding the maximumlikelihood estimates of the weights <math>w_k</math>. But remember the constraint <math>\sum_{k=1}^K w_k = 1</math>. To enforce it, we can use a [http://en.wikipedia.org/wiki/Lagrange_multiplier Lagrange multiplier], <math>\lambda</math>. This means that we now need to maximize the following equation where <math>\Lambda</math> is a Lagrange function (only the part of Q being a function of the weights is kept):  
+  
+  <math>\Lambda(w_k,\lambda) = \sum_{i=1}^N \left( \sum_{k=1}^K p(ki) ln(w_k) \right) + \lambda (1  \sum_{k=1}^K w_k)</math>  
+  
+  As usual, to find the maximum, we derive and equal to zero:  
+  
+  <math>\frac{\Lambda}{\partial w_k} = \sum_{i=1}^N \left( p(ki) \frac{1}{\hat{w}_k^{(t+1)}} \right)  \lambda = 0</math>  
+  
+  <math>\hat{w}_k^{(t+1)} = \frac{1}{\lambda} \sum_{i=1}^N p(ki)</math>  
+  
+  Now, to find the multiplier, we go back to the constraint:  
  <math>\  +  <math>\sum_{k=1}^K \hat{w}_k^{(t+1)} = 1 \rightarrow \lambda = \sum_{i=1}^N \sum_{k=1}^K p(ki) = N</math> 
  +  Finally:  
  <math>\  +  <math>\hat{w}_k^{(t+1)} = \frac{1}{N} \sum_{i=1}^N p(ki)</math> 
  
  +  * '''M step  means''':  
  +  <math>\frac{\partial Q}{\partial \mu_k} = \sum_{i=1}^N p(ki) \frac{\partial ln(\phi(x_i\mu_k,\sigma_k))}{\partial \mu_k}</math>  
  <math>\frac{\partial  +  <math>\frac{\partial Q}{\partial \mu_k} = \sum_{i=1}^N p(ki) \frac{1}{\phi(x_i\mu_k,\sigma_k)} \frac{\partial \phi(x_i\mu_k,\sigma_k)}{\partial \mu_k}</math> 
  +  <math>\frac{\partial Q}{\partial \mu_k} = 0 = \sum_{i=1}^N p(ki) (x_i  \hat{\mu}_k^{(t+1)})</math>  
  +  Finally:  
  +  <math>\hat{\mu}_k^{(t+1)} = \frac{\sum_{i=1}^N p(k/i) x_i}{\sum_{i=1}^N p(k/i)}</math>  
  
  +  * '''M step  variances''': same kind of algebra  
  <math>\frac{\partial  +  <math>\frac{\partial Q}{\partial \sigma_k} = \sum_{i=1}^N p(k/i) (\frac{1}{\sigma_k} + \frac{(x_i  \mu_k)^2}{\sigma_k^3})</math> 
  +  <math>\hat{\sigma}_k^{(t+1)} = \sqrt{\frac{\sum_{i=1}^N p(k/i) (x_i  \hat{\mu}_k^{(t+1)})^2}{\sum_{i=1}^N p(k/i)}}</math>  
  
  +  * '''M step  weights (2)''': we can write them in terms of unconstrained variables <math>\gamma_k</math> ([http://en.wikipedia.org/wiki/Softmax_activation_function softmax function]):  
<math>w_k = \frac{e^{\gamma_k}}{\sum_{k=1}^K e^{\gamma_k}}</math>  <math>w_k = \frac{e^{\gamma_k}}{\sum_{k=1}^K e^{\gamma_k}}</math>  
Line 156:  Line 186:  
<math>\frac{\partial l(\theta)}{\partial w_k} = \sum_{i=1}^N (p(k/i)  w_k)</math>  <math>\frac{\partial l(\theta)}{\partial w_k} = \sum_{i=1}^N (p(k/i)  w_k)</math>  
  Finally  +  Finally: 
<math>\hat{w}_k = \frac{1}{N} \sum_{i=1}^N p(k/i)</math>  <math>\hat{w}_k = \frac{1}{N} \sum_{i=1}^N p(k/i)</math>  
Line 273:  Line 303:  
K < ncol(membership.probas)  K < ncol(membership.probas)  
sapply(1:K, function(k){  sapply(1:K, function(k){  
  sqrt(sum(unlist(Map(function(  +  sqrt(sum(unlist(Map(function(p.ki, x.i){ 
  +  p.ki * (x.i  means[k])^2  
}, membership.probas[,k], data))) /  }, membership.probas[,k], data))) /  
sum.membership.probas[k])  sum.membership.probas[k])  
Line 365:  Line 395:  
ds < lapply(1:K, function(k){dnorm(x=rx, mean=res$params$mus[k], sd=res$params$sigmas[k])})  ds < lapply(1:K, function(k){dnorm(x=rx, mean=res$params$mus[k], sd=res$params$sigmas[k])})  
f < sapply(1:length(rx), function(i){  f < sapply(1:length(rx), function(i){  
  res$params$mix.weights[1] * ds[[1]][i] + res$params$mix.weights[2] * ds[[2]][i] + res$params$mix.weights[3] *  +  res$params$mix.weights[1] * ds[[1]][i] + res$params$mix.weights[2] * ds[[2]][i] + res$params$mix.weights[3] * ds[[3]][i] 
})  })  
lines(rx, f, col="red", lwd=2)  lines(rx, f, col="red", lwd=2)  
Line 384:  Line 414:  
* '''References''':  * '''References''':  
  ** chapter 1 from the PhD thesis of Matthew Stephens (Oxford, 2000)  +  ** chapter 1 from the PhD thesis of Matthew Stephens (Oxford, 2000) freely available [http://www.stat.washington.edu/stephens/papers/tabstract.html online] 
  ** chapter 2 from the PhD thesis of Matthew Beal (UCL, 2003)  +  ** chapter 2 from the PhD thesis of Matthew Beal (UCL, 2003) freely available [http://www.cse.buffalo.edu/faculty/mbeal/thesis/ online] 
  ** lecture "Mixture Models, Latent Variables and the EM Algorithm" from Cosma Shalizi  +  ** lecture "Mixture Models, Latent Variables and the EM Algorithm" from Cosma Shalizi freely available [http://www.stat.cmu.edu/~cshalizi/uADA/12/ online] 
** book "Introducing Monte Carlo Methods with R" from Robert and and Casella (2009)  ** book "Introducing Monte Carlo Methods with R" from Robert and and Casella (2009)  
Revision as of 17:33, 8 May 2013
Project name  Main project page Previous entry Next entry 
Learn about mixture models and the EM algorithm(Caution, this is my own quickanddirty tutorial, see the references at the end for presentations by professional statisticians.)
The constraints are: and
As usual, it's easier to deal with the loglikelihood:
Let's take the derivative with respect to one parameter, eg. θ_{l}:
This shows that maximizing the likelihood of a mixture model is like doing a weighted likelihood maximization. However, these weights depend on the parameters we want to estimate! That's why we now switch to the missingdata formulation of the mixture model.
The observeddata likelihood (also called sometimes "incomplete" or "marginal", even though these appellations are misnomers) is still written the same way:
But now we can also write the augmenteddata likelihood, assuming all observations are independent conditionally on their membership: . And here is the augmenteddata loglikelihood (useful in the M step of the EM algorithm, see below):
In terms of graphical model, the Gaussian mixture model described here can be represented like this.
Here is the observeddata loglikelihood:
First we introduce the hidden variables by integrating them out:
Then, we use any probability distribution q on these hidden variables (in fact, we use a distinct distribution for each observation):
And here is the great trick, as explained by Beal: "any probability distribution over the hidden variables gives rise to a lower bound on l_{obs}". This is due to to the Jensen inequality (the logarithm is concave):
At each iteration, the E step maximizes the lower bound () with respect to the :
The Estep amounts to inferring the posterior distribution of the hidden variables given the current parameter θ^{(t)}:
Indeed, the make the bound tight (the inequality becomes an equality):
Then, at the M step, we use these statistics to maximize the new lower bound with respect to θ, and therefore find θ^{(t + 1)}.
As a result, the E step may not always lead to a tight bound.
As usual, to find the maximum, we derive and equal to zero:
Now, to find the multiplier, we go back to the constraint:
Finally:
Finally:
Finally:
#' Generate univariate observations from a mixture of Normals #' #' @param K number of components #' @param N number of observations #' @param gap difference between all component means GetUnivariateSimulatedData < function(K=2, N=100, gap=6){ mus < seq(0, gap*(K1), gap) sigmas < runif(n=K, min=0.5, max=1.5) tmp < floor(rnorm(n=K1, mean=floor(N/K), sd=5)) ns < c(tmp, N  sum(tmp)) clusters < as.factor(matrix(unlist(lapply(1:K, function(k){rep(k, ns[k])})), ncol=1)) obs < matrix(unlist(lapply(1:K, function(k){ rnorm(n=ns[k], mean=mus[k], sd=sigmas[k]) }))) new.order < sample(1:N, N) obs < obs[new.order] rownames(obs) < NULL clusters < clusters[new.order] return(list(obs=obs, clusters=clusters, mus=mus, sigmas=sigmas, mix.weights=ns/N)) }
#' Return probas of latent variables given data and parameters from previous iteration #' #' @param data Nx1 vector of observations #' @param params list which components are mus, sigmas and mix.weights Estep < function(data, params){ GetMembershipProbas(data, params$mus, params$sigmas, params$mix.weights) } #' Return the membership probabilities P(zi=k/xi,theta) #' #' @param data Nx1 vector of observations #' @param mus Kx1 vector of means #' @param sigmas Kx1 vector of std deviations #' @param mix.weights Kx1 vector of mixture weights w_k=P(zi=k/theta) #' @return NxK matrix of membership probas GetMembershipProbas < function(data, mus, sigmas, mix.weights){ N < length(data) K < length(mus) tmp < matrix(unlist(lapply(1:N, function(i){ x < data[i] norm.const < sum(unlist(Map(function(mu, sigma, mix.weight){ mix.weight * GetUnivariateNormalDensity(x, mu, sigma)}, mus, sigmas, mix.weights))) unlist(Map(function(mu, sigma, mix.weight){ mix.weight * GetUnivariateNormalDensity(x, mu, sigma) / norm.const }, mus[K], sigmas[K], mix.weights[K])) })), ncol=K1, byrow=TRUE) membership.probas < cbind(tmp, apply(tmp, 1, function(x){1  sum(x)})) names(membership.probas) < NULL return(membership.probas) } #' Univariate Normal density GetUnivariateNormalDensity < function(x, mu, sigma){ return( 1/(sigma * sqrt(2*pi)) * exp(1/(2*sigma^2)*(xmu)^2) ) }
#' Return ML estimates of parameters #' #' @param data Nx1 vector of observations #' @param params list which components are mus, sigmas and mix.weights #' @param membership.probas NxK matrix with entry i,k being P(zi=k/xi,theta) Mstep < function(data, params, membership.probas){ params.new < list() sum.membership.probas < apply(membership.probas, 2, sum) params.new$mus < GetMlEstimMeans(data, membership.probas, sum.membership.probas) params.new$sigmas < GetMlEstimStdDevs(data, params.new$mus, membership.probas, sum.membership.probas) params.new$mix.weights < GetMlEstimMixWeights(data, membership.probas, sum.membership.probas) return(params.new) } #' Return ML estimates of the means (1 per cluster) #' #' @param data Nx1 vector of observations #' @param membership.probas NxK matrix with entry i,k being P(zi=k/xi,theta) #' @param sum.membership.probas Kx1 vector of sum per column of matrix above #' @return Kx1 vector of means GetMlEstimMeans < function(data, membership.probas, sum.membership.probas){ K < ncol(membership.probas) sapply(1:K, function(k){ sum(unlist(Map("*", membership.probas[,k], data))) / sum.membership.probas[k] }) } #' Return ML estimates of the std deviations (1 per cluster) #' #' @param data Nx1 vector of observations #' @param membership.probas NxK matrix with entry i,k being P(zi=k/xi,theta) #' @param sum.membership.probas Kx1 vector of sum per column of matrix above #' @return Kx1 vector of std deviations GetMlEstimStdDevs < function(data, means, membership.probas, sum.membership.probas){ K < ncol(membership.probas) sapply(1:K, function(k){ sqrt(sum(unlist(Map(function(p.ki, x.i){ p.ki * (x.i  means[k])^2 }, membership.probas[,k], data))) / sum.membership.probas[k]) }) } #' Return ML estimates of the mixture weights #' #' @param data Nx1 vector of observations #' @param membership.probas NxK matrix with entry i,k being P(zi=k/xi,theta) #' @param sum.membership.probas Kx1 vector of sum per column of matrix above #' @return Kx1 vector of mixture weights GetMlEstimMixWeights < function(data, membership.probas, sum.membership.probas){ K < ncol(membership.probas) sapply(1:K, function(k){ 1/length(data) * sum.membership.probas[k] }) }
GetLogLikelihood < function(data, mus, sigmas, mix.weights){ loglik < sum(sapply(data, function(x){ log(sum(unlist(Map(function(mu, sigma, mix.weight){ mix.weight * GetUnivariateNormalDensity(x, mu, sigma) }, mus, sigmas, mix.weights)))) })) return(loglik) }
EMalgo < function(data, params, threshold.convergence=10^(2), nb.iter=10, verbose=1){ logliks < vector() i < 1 if(verbose > 0) cat(paste("iter ", i, "\n", sep="")) membership.probas < Estep(data, params) params < Mstep(data, params, membership.probas) loglik < GetLogLikelihood(data, params$mus, params$sigmas, params$mix.weights) logliks < append(logliks, loglik) while(i < nb.iter){ i < i + 1 if(verbose > 0) cat(paste("iter ", i, "\n", sep="")) membership.probas < Estep(data, params) params < Mstep(data, params, membership.probas) loglik < GetLogLikelihood(data, params$mus, params$sigmas, params$mix.weights) if(loglik < logliks[length(logliks)]){ msg < paste("the loglikelihood is decreasing:", loglik, "<", logliks[length(logliks)]) stop(msg, call.=FALSE) } logliks < append(logliks, loglik) if(abs(logliks[i]  logliks[i1]) <= threshold.convergence) break } return(list(params=params, membership.probas=membership.probas, logliks=logliks, nb.iters=i)) }
## simulate data K < 3 N < 300 simul < GetUnivariateSimulatedData(K, N) data < simul$obs ## run the EM algorithm params0 < list(mus=runif(n=K, min=min(data), max=max(data)), sigmas=rep(1, K), mix.weights=rep(1/K, K)) res < EMalgo(data, params0, 10^(3), 1000, 1) ## check its convergence plot(res$logliks, xlab="iterations", ylab="loglikelihood", main="Convergence of the EM algorithm", type="b") ## plot the data along with the inferred densities png("mixture_univar_em.png") hist(data, breaks=30, freq=FALSE, col="grey", border="white", ylim=c(0,0.15), main="Histogram of data overlaid with densities inferred by EM") rx < seq(from=min(data), to=max(data), by=0.1) ds < lapply(1:K, function(k){dnorm(x=rx, mean=res$params$mus[k], sd=res$params$sigmas[k])}) f < sapply(1:length(rx), function(i){ res$params$mix.weights[1] * ds[[1]][i] + res$params$mix.weights[2] * ds[[2]][i] + res$params$mix.weights[3] * ds[[3]][i] }) lines(rx, f, col="red", lwd=2) dev.off() It seems to work well, which was expected as the clusters are well separated from each other... The classification of each observation can be obtained via the following command: ## get the classification of the observations memberships < apply(res$membership.probas, 1, function(x){which(x > 0.5)}) table(memberships)
