Difference between revisions of "User:Timothee Flutre/Notebook/Postdoc/2011/12/14"

From OpenWetWare
Jump to: navigation, search
m (Learn about mixture models and the EM algorithm: minor refactoring)
(Learn about mixture models and the EM algorithm: clarify E step)
(4 intermediate revisions by the same user not shown)
Line 46: Line 46:
 
* '''Missing data''': we introduce the following N [http://en.wikipedia.org/wiki/Latent_variable latent variables] <math>Z_1,...,Z_i,...,Z_N</math> (also called hidden or allocation variables), one for each observation, such that <math>Z_i=k</math> means that observation <math>x_i</math> belongs to cluster <math>k</math>. In fact, it is much easier to work the equations when defining each <math>Z_i</math> as a vector of length <math>K</math>, with <math>Z_{ik}=1</math> if observation <math>x_i</math> belongs to cluster <math>k</math>, and <math>Z_{ik}=0</math> otherwise ([http://en.wikipedia.org/wiki/Dummy_variable_%28statistics%29 indicator variables]). Thanks to this, we can reinterpret the mixture weights: <math>\forall i, P(Z_i=k|\theta)=w_k</math>. Moreover, we can now define the membership probabilities, one for each observation:
 
* '''Missing data''': we introduce the following N [http://en.wikipedia.org/wiki/Latent_variable latent variables] <math>Z_1,...,Z_i,...,Z_N</math> (also called hidden or allocation variables), one for each observation, such that <math>Z_i=k</math> means that observation <math>x_i</math> belongs to cluster <math>k</math>. In fact, it is much easier to work the equations when defining each <math>Z_i</math> as a vector of length <math>K</math>, with <math>Z_{ik}=1</math> if observation <math>x_i</math> belongs to cluster <math>k</math>, and <math>Z_{ik}=0</math> otherwise ([http://en.wikipedia.org/wiki/Dummy_variable_%28statistics%29 indicator variables]). Thanks to this, we can reinterpret the mixture weights: <math>\forall i, P(Z_i=k|\theta)=w_k</math>. Moreover, we can now define the membership probabilities, one for each observation:
  
<math>p(k|i) = P(z_{ik}=1|x_i,\theta) = \frac{w_k \phi(x_i|\mu_k,\sigma_k)}{\sum_{l=1}^K w_l \phi(x_i|\mu_l,\sigma_l)}</math>
+
<math>p(k|i) = P(Z_{ik}=1|x_i,\theta) = \frac{w_k \phi(x_i|\mu_k,\sigma_k)}{\sum_{l=1}^K w_l \phi(x_i|\mu_l,\sigma_l)}</math>
  
 
The observed-data likelihood (also called sometimes "incomplete" or "marginal", even though these appellations are misnomers) is still written the same way:
 
The observed-data likelihood (also called sometimes "incomplete" or "marginal", even though these appellations are misnomers) is still written the same way:
Line 54: Line 54:
 
But now we can also write the augmented-data likelihood, assuming all observations are independent conditionally on their membership:
 
But now we can also write the augmented-data likelihood, assuming all observations are independent conditionally on their membership:
  
<math>L_{aug}(\theta) = P(X,Z|\theta) = \prod_{i=1}^N P(x_i|z_i,\theta) P(z_i|\theta) = \prod_{i=1}^N \left( \prod_{k=1}^K \phi(x_i|\mu_k,\sigma_k)^{z_{ik}} w_k^{z_{ik}} \right)</math>.
+
<math>L_{aug}(\theta) = P(X,Z|\theta) = \prod_{i=1}^N P(x_i|Z_i,\theta) P(Z_i|\theta) = \prod_{i=1}^N \left( \prod_{k=1}^K \phi(x_i|\mu_k,\sigma_k)^{Z_{ik}} w_k^{Z_{ik}} \right)</math>.
 +
 
 +
And here is the augmented-data log-likelihood (useful in the M step of the EM algorithm, see below):
 +
 
 +
<math>l_{aug}(\theta) = \sum_{i=1}^N \left( \sum_{k=1}^K Z_{ik} ln(\phi(x_i|\mu_k,\sigma_k)) + \sum_{k=1}^K Z_{ik} ln(w_k) \right)</math>
 +
 
 +
In terms of [http://en.wikipedia.org/wiki/Graphical_model graphical model], the Gaussian mixture model described here can be represented like [http://en.wikipedia.org/wiki/File:Nonbayesian-gaussian-mixture.svg this].
  
  
Line 112: Line 118:
  
  
* '''ML formulas of the M step''': a few important rules are required to write down the analytical formulae of the MLEs, but only from a high-school level (see [http://en.wikipedia.org/wiki/Differentiation_%28mathematics%29#Rules_for_finding_the_derivative here]). Let's start by finding the maximum-likelihood estimates of the mean of each cluster:
+
* '''Formulas of both steps''': in both steps we need to use <math>Q</math>, whether to evaluate it or maximize it.
 +
 
 +
<math>Q(\theta|X,\theta^{(t)}) = \mathbb{E}_{Z|X,\theta^{(t)}} \left[ ln(P(X,Z|\theta))|X,\theta^{(t)} \right]</math>
 +
 
 +
<math>Q(\theta|X,\theta^{(t)}) = \mathbb{E}_{Z|X,\theta^{(t)}} \left[ l_{aug}(\theta)|X,\theta^{(t)} \right]</math>
 +
 
 +
<math>Q(\theta|X,\theta^{(t)}) = \sum_{i=1}^N \left( \sum_{k=1}^K \mathbb{E}_{Z|X,\theta^{(t)}}[Z_{ik}|x_i,\theta_k^{(t)}] ln(\phi(x_i|\mu_k,\sigma_k)) + \sum_{k=1}^K \mathbb{E}_{Z|X,\theta^{(t)}}[Z_{ik}|x_i,\theta_k^{(t)}] ln(w_k) \right)</math>
 +
 
 +
 
 +
* '''Formulas of the E step''': as indicated above, the E step consists in evaluating <math>Q</math>, i.e. simply evaluating the conditional expectation over the latent variables of the augmented-data log-likelihood given the observed data and the current estimates of the parameters.
 +
 
 +
<math>\mathbb{E}_{Z|X,\theta^{(t)}}[Z_{ik}|x_i,\theta_k^{(t)}] = P(Z_{ik}=1|x_i,\theta_k^{(t)}) = \frac{w_k^{(t)} \phi(x_i|\mu_k^{(t)},\sigma_k^{(t)})}{\sum_{l=1}^K w_l^{(t)} \phi(x_i|\mu_l^{(t)},\sigma_l^{(t)})} = p(k|i)</math>
 +
 
 +
 
 +
* '''Formulas of the M step''': in this step, we need to maximize <math>Q</math> (also written <math>\mathcal{F}</math> above), w.r.t. each <math>\theta_k</math>. A few important rules are required to write down the analytical formulas of the MLEs, but only from a high-school level (see [http://en.wikipedia.org/wiki/Differentiation_%28mathematics%29#Rules_for_finding_the_derivative here]).
 +
 
 +
 
 +
* '''M step - weights''': let's start by finding the maximum-likelihood estimates of the weights <math>w_k</math>. But remember the constraint <math>\sum_{k=1}^K w_k = 1</math>. To enforce it, we can use a [http://en.wikipedia.org/wiki/Lagrange_multiplier Lagrange multiplier], <math>\lambda</math>. This means that we now need to maximize the following equation where <math>\Lambda</math> is a Lagrange function (only the part of Q being a function of the weights is kept):
 +
 
 +
<math>\Lambda(w_k,\lambda) = \sum_{i=1}^N \left( \sum_{k=1}^K p(k|i) ln(w_k) \right) + \lambda (1 - \sum_{k=1}^K w_k)</math>
 +
 
 +
As usual, to find the maximum, we derive and equal to zero:
 +
 
 +
<math>\frac{\Lambda}{\partial w_k} = \sum_{i=1}^N \left( p(k|i) \frac{1}{\hat{w}_k^{(t+1)}} \right) - \lambda = 0</math>
 +
 
 +
<math>\hat{w}_k^{(t+1)} = \frac{1}{\lambda} \sum_{i=1}^N p(k|i)</math>
 +
 
 +
Now, to find the multiplier, we go back to the constraint:
  
<math>\frac{\partial l(\theta)}{\partial \mu_k} = \sum_{i=1}^N \frac{1}{f(x_i/\theta)} \frac{\partial f(x_i/\theta)}{\partial \mu_k}</math>
+
<math>\sum_{k=1}^K \hat{w}_k^{(t+1)} = 1 \rightarrow \lambda = \sum_{i=1}^N \sum_{k=1}^K p(k|i) = N</math>
  
As we derive with respect to <math>\mu_k</math>, all the others means <math>\mu_l</math> with <math>l \ne k</math> are constant, and thus disappear:
+
Finally:
  
<math>\frac{\partial f(x_i/\theta)}{\partial \mu_k} = w_k \frac{\partial g(x_i/\mu_k,\sigma_k)}{\partial \mu_k}</math>
+
<math>\hat{w}_k^{(t+1)} = \frac{1}{N} \sum_{i=1}^N p(k|i)</math>
  
And finally:
 
  
<math>\frac{\partial g(x_i/\mu_k,\sigma_k)}{\partial \mu_k} = \frac{\mu_k - x_i}{\sigma_k^2} g(x_i/\mu_k,\sigma_k)</math>
+
* '''M step - means''':
  
Once we put all together, we end up with:
+
<math>\frac{\partial Q}{\partial \mu_k} = \sum_{i=1}^N p(k|i) \frac{\partial ln(\phi(x_i|\mu_k,\sigma_k))}{\partial \mu_k}</math>
  
<math>\frac{\partial l(\theta)}{\partial \mu_k} = \sum_{i=1}^N \frac{1}{\sigma^2} \frac{w_k g(x_i/\mu_k,\sigma_k)}{\sum_{l=1}^K w_l g(x_i/\mu_l,\sigma_l)} (\mu_k - x_i) = \sum_{i=1}^N \frac{1}{\sigma^2} p(k/i) (\mu_k - x_i)</math>
+
<math>\frac{\partial Q}{\partial \mu_k} = \sum_{i=1}^N p(k|i) \frac{1}{\phi(x_i|\mu_k,\sigma_k)} \frac{\partial \phi(x_i|\mu_k,\sigma_k)}{\partial \mu_k}</math>
  
By convention, we note <math>\hat{\mu_k}</math> the maximum-likelihood estimate of <math>\mu_k</math>:
+
<math>\frac{\partial Q}{\partial \mu_k} = 0 = \sum_{i=1}^N p(k|i) (x_i - \hat{\mu}_k^{(t+1)})</math>
  
<math>\frac{\partial l(\theta)}{\partial \mu_k}_{\mu_k=\hat{\mu_k}} = 0</math>
+
Finally:
  
Therefore, we finally obtain:
+
<math>\hat{\mu}_k^{(t+1)} = \frac{\sum_{i=1}^N p(k/i) x_i}{\sum_{i=1}^N p(k/i)}</math>
  
<math>\hat{\mu_k} = \frac{\sum_{i=1}^N p(k/i) x_i}{\sum_{i=1}^N p(k/i)}</math>
 
  
By doing the same kind of algebra, we derive the log-likelihood w.r.t. <math>\sigma_k</math>:
+
* '''M step - variances''': same kind of algebra
  
<math>\frac{\partial l(\theta)}{\partial \sigma_k} = \sum_{i=1}^N p(k/i) (\frac{-1}{\sigma_k} + \frac{(x_i - \mu_k)^2}{\sigma_k^3})</math>
+
<math>\frac{\partial Q}{\partial \sigma_k} = \sum_{i=1}^N p(k/i) (\frac{-1}{\sigma_k} + \frac{(x_i - \mu_k)^2}{\sigma_k^3})</math>
  
And then we obtain the ML estimates for the standard deviation of each cluster:
+
<math>\hat{\sigma}_k^{(t+1)} = \sqrt{\frac{\sum_{i=1}^N p(k/i) (x_i - \hat{\mu}_k^{(t+1)})^2}{\sum_{i=1}^N p(k/i)}}</math>
  
<math>\hat{\sigma_k} = \sqrt{\frac{\sum_{i=1}^N p(k/i) (x_i - \mu_k)^2}{\sum_{i=1}^N p(k/i)}}</math>
 
  
The partial derivative of <math>l(\theta)</math> w.r.t. <math>w_k</math> is tricky because of the constraints on the <math>w_k</math>. But we can handle it by writing them in terms of unconstrained variables <math>\gamma_k</math> ([http://en.wikipedia.org/wiki/Softmax_activation_function softmax function]):
+
* '''M step - weights (2)''': we can write them in terms of unconstrained variables <math>\gamma_k</math> ([http://en.wikipedia.org/wiki/Softmax_activation_function softmax function]):
  
 
<math>w_k = \frac{e^{\gamma_k}}{\sum_{k=1}^K e^{\gamma_k}}</math>
 
<math>w_k = \frac{e^{\gamma_k}}{\sum_{k=1}^K e^{\gamma_k}}</math>
Line 156: Line 186:
 
<math>\frac{\partial l(\theta)}{\partial w_k} = \sum_{i=1}^N (p(k/i) - w_k)</math>
 
<math>\frac{\partial l(\theta)}{\partial w_k} = \sum_{i=1}^N (p(k/i) - w_k)</math>
  
Finally, here are the ML estimates for the mixture weights:
+
Finally:
  
 
<math>\hat{w}_k = \frac{1}{N} \sum_{i=1}^N p(k/i)</math>
 
<math>\hat{w}_k = \frac{1}{N} \sum_{i=1}^N p(k/i)</math>
Line 273: Line 303:
 
   K <- ncol(membership.probas)
 
   K <- ncol(membership.probas)
 
   sapply(1:K, function(k){
 
   sapply(1:K, function(k){
     sqrt(sum(unlist(Map(function(p_ki, x_i){
+
     sqrt(sum(unlist(Map(function(p.ki, x.i){
       p_ki * (x_i - means[k])^2
+
       p.ki * (x.i - means[k])^2
 
     }, membership.probas[,k], data))) /
 
     }, membership.probas[,k], data))) /
 
         sum.membership.probas[k])
 
         sum.membership.probas[k])
Line 365: Line 395:
 
ds <- lapply(1:K, function(k){dnorm(x=rx, mean=res$params$mus[k], sd=res$params$sigmas[k])})
 
ds <- lapply(1:K, function(k){dnorm(x=rx, mean=res$params$mus[k], sd=res$params$sigmas[k])})
 
f <- sapply(1:length(rx), function(i){
 
f <- sapply(1:length(rx), function(i){
   res$params$mix.weights[1] * ds[[1]][i] + res$params$mix.weights[2] * ds[[2]][i] + res$params$mix.weights[3] * s[[3]][i]
+
   res$params$mix.weights[1] * ds[[1]][i] + res$params$mix.weights[2] * ds[[2]][i] + res$params$mix.weights[3] * ds[[3]][i]
 
})
 
})
 
lines(rx, f, col="red", lwd=2)
 
lines(rx, f, col="red", lwd=2)
Line 384: Line 414:
  
 
* '''References''':
 
* '''References''':
** chapter 1 from the PhD thesis of Matthew Stephens (Oxford, 2000)
+
** chapter 1 from the PhD thesis of Matthew Stephens (Oxford, 2000) freely available [http://www.stat.washington.edu/stephens/papers/tabstract.html online]
** chapter 2 from the PhD thesis of Matthew Beal (UCL, 2003)
+
** chapter 2 from the PhD thesis of Matthew Beal (UCL, 2003) freely available [http://www.cse.buffalo.edu/faculty/mbeal/thesis/ online]
** lecture "Mixture Models, Latent Variables and the EM Algorithm" from Cosma Shalizi
+
** lecture "Mixture Models, Latent Variables and the EM Algorithm" from Cosma Shalizi freely available [http://www.stat.cmu.edu/~cshalizi/uADA/12/ online]
 
** book "Introducing Monte Carlo Methods with R" from Robert and and Casella (2009)
 
** book "Introducing Monte Carlo Methods with R" from Robert and and Casella (2009)
  

Revision as of 14:33, 8 May 2013

Owwnotebook icon.png Project name <html><img src="/images/9/94/Report.png" border="0" /></html> Main project page
<html><img src="/images/c/c3/Resultset_previous.png" border="0" /></html>Previous entry<html>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</html>Next entry<html><img src="/images/5/5c/Resultset_next.png" border="0" /></html>

Learn about mixture models and the EM algorithm

(Caution, this is my own quick-and-dirty tutorial, see the references at the end for presentations by professional statisticians.)

  • Motivation: a large part of any scientific activity is about measuring things, in other words collecting data, and it is not infrequent to collect heterogeneous data. It seems therefore natural to say that the samples come from a mixture of clusters. The aim is thus to recover from the data, ie. to infer, (i) how many clusters there are, (ii) what are the features of these clusters, and (iii) from which cluster each sample comes from. In the following, I will focus on points (ii) and (iii).
  • Data: we have N observations, noted Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle X = (x_1, x_2, ..., x_N)} . For the moment, we suppose that each observation Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle x_i} is univariate, ie. each corresponds to only one number.
  • Assumption: let's assume that the data are heterogeneous and that they can be partitioned into Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle K} clusters (in this document, we suppose that Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle K} is known). This means that we expect a subset of the observations to come from cluster Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle k=1} , another subset to come from cluster Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle k=2} , and so on.


  • Model: technically, we say that the observations were generated according to a density function Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle f} . More precisely, this density is itself a mixture of densities, one per cluster. In our case, we will assume that observations from cluster Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle k} are generated from a Normal distribution, which density is here noted Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \phi} , with mean Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mu_k} and standard deviation Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \sigma_k} . Moreover, as we don't know for sure from which cluster a given observation comes from, we define the mixture weight Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle w_k} (also called mixing proportion) to be the probability that any given observation comes from cluster Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle k} . As a result, we have the following list of parameters: Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \theta=(w_1,...,w_K,\mu_1,...\mu_K,\sigma_1,...,\sigma_K)} . Finally, for a given observation Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle x_i} , we can write the model:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle f(x_i|\theta) = \sum_{k=1}^{K} w_k \phi(x_i|\mu_k,\sigma_k) = \sum_{k=1}^{K} w_k \frac{1}{\sqrt{2\pi} \sigma_k} \exp \left(-\frac{1}{2}(\frac{x_i - \mu_k}{\sigma_k})^2 \right)}

The constraints are: Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \forall k, w_k > 0} and Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \sum_{k=1}^K w_k = 1}


  • Maximum-likelihood: naturally, we can start by maximizing the likelihood in order to estimate the parameters:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle L(\theta) = P(X|\theta) = \prod_{i=1}^N f(x_i|\theta)}

As usual, it's easier to deal with the log-likelihood:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle l(\theta) = \sum_{i=1}^N ln \left( f(x_i|\theta) \right) = \sum_{i=1}^N ln \left( \sum_{k=1}^K w_k \phi(x_i; \theta_k) \right)}

Let's take the derivative with respect to one parameter, eg. Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \theta_l} :

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial l}{\partial \theta_l} = \sum_{i=1}^N \frac{1}{\sum_{k=1}^K w_k \phi(x_i; \theta_k)} w_l \frac{\partial \phi(x_i; \theta_l)}{\partial \theta_l}}

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial l}{\partial \theta_l} = \sum_{i=1}^N \frac{w_l \phi(x_i; \theta_l)}{\sum_{k=1}^K w_k \phi(x_i; \theta_k)} \frac{1}{\phi(x_i; \theta_l)} \frac{\partial \phi(x_i; \theta_l)}{\partial \theta_l}}

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial l}{\partial \theta_l} = \sum_{i=1}^N \frac{w_l \phi(x_i; \theta_l)}{\sum_{k=1}^K w_k \phi(x_i; \theta_k)} \frac{\partial ln ( \phi(x_i; \theta_l) )}{\partial \theta_l}}

This shows that maximizing the likelihood of a mixture model is like doing a weighted likelihood maximization. However, these weights depend on the parameters we want to estimate! That's why we now switch to the missing-data formulation of the mixture model.


  • Missing data: we introduce the following N latent variables Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle Z_1,...,Z_i,...,Z_N} (also called hidden or allocation variables), one for each observation, such that Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle Z_i=k} means that observation Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle x_i} belongs to cluster Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle k} . In fact, it is much easier to work the equations when defining each Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle Z_i} as a vector of length Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle K} , with Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle Z_{ik}=1} if observation Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle x_i} belongs to cluster Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle k} , and Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle Z_{ik}=0} otherwise (indicator variables). Thanks to this, we can reinterpret the mixture weights: Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \forall i, P(Z_i=k|\theta)=w_k} . Moreover, we can now define the membership probabilities, one for each observation:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle p(k|i) = P(Z_{ik}=1|x_i,\theta) = \frac{w_k \phi(x_i|\mu_k,\sigma_k)}{\sum_{l=1}^K w_l \phi(x_i|\mu_l,\sigma_l)}}

The observed-data likelihood (also called sometimes "incomplete" or "marginal", even though these appellations are misnomers) is still written the same way:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle L_{obs}(\theta) = P(X|\theta) = \prod_{i=1}^N f(x_i|\theta)}

But now we can also write the augmented-data likelihood, assuming all observations are independent conditionally on their membership:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle L_{aug}(\theta) = P(X,Z|\theta) = \prod_{i=1}^N P(x_i|Z_i,\theta) P(Z_i|\theta) = \prod_{i=1}^N \left( \prod_{k=1}^K \phi(x_i|\mu_k,\sigma_k)^{Z_{ik}} w_k^{Z_{ik}} \right)} .

And here is the augmented-data log-likelihood (useful in the M step of the EM algorithm, see below):

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle l_{aug}(\theta) = \sum_{i=1}^N \left( \sum_{k=1}^K Z_{ik} ln(\phi(x_i|\mu_k,\sigma_k)) + \sum_{k=1}^K Z_{ik} ln(w_k) \right)}

In terms of graphical model, the Gaussian mixture model described here can be represented like this.


  • EM algorithm - definition: the idea is to iterate two steps, starting from randomly-initialized parameters. In the E-step, one computes the conditional expectation of the augmented-data log-likelihood function over the latent variables given the observed data and the parameter estimates from the previous iteration. Second, in the M-step, one maximizes this expected augmented-data log-likelihood function to determine the next iterate of the parameter estimates.
    • E step: Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle Q(\theta|X,\theta^{(t)}) = \mathbb{E}_{Z|X,\theta^{(t)}} \left[ ln(P(X,Z|\theta))|X,\theta^{(t)} \right] = \int l_{aug} q(Z|X,\theta^{(t)}) dZ}
    • M-step: Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \theta^{(t+1)} = argmax_{\theta} Q(\theta|X,\theta^{(t)})} so that Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \forall \theta \in \Theta, Q(\theta^{(t+1)}|X,\theta^{(t)}) \ge Q(\theta|X,\theta^{(t)})}


  • EM algorithm - theory: stated like this above doesn't necessarily allow oneself to understand it immediately, at least in my case. Hopefully, Matthew Beal presents it in a great and simple way in his PhD thesis (see references at the bottom of the page).

Here is the observed-data log-likelihood:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle l_{obs}(\theta) = \sum_{i=1}^N ln \left( f(x_i|\theta) \right)}

First we introduce the hidden variables by integrating them out:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle l_{obs}(\theta) = \sum_{i=1}^N ln \left( \int p(x_i,z_i|\theta) dz_i \right)}

Then, we use any probability distribution Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle q} on these hidden variables (in fact, we use a distinct distribution Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle q_{z_i}} for each observation):

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle l_{obs}(\theta) = \sum_{i=1}^N ln \left( \int q_{z_i}(z_i) \frac{p(x_i,z_i|\theta)}{q_{z_i}(z_i)} dz_i \right)}

And here is the great trick, as explained by Beal: "any probability distribution over the hidden variables gives rise to a lower bound on Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle l_{obs}} ". This is due to to the Jensen inequality (the logarithm is concave):

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle l_{obs}(\theta) \ge \sum_{i=1}^N \int q_{z_i}(z_i) ln \left( \frac{p(x_i,z_i|\theta)}{q_{z_i}(z_i)} \right) dz_i = \mathcal{F}(q_{z_1}(z_1), ..., q_{z_N}(z_N), \theta)}

At each iteration, the E step maximizes the lower bound (Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mathcal{F}} ) with respect to the Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle q_{z_i}(z_i)} :

  • E step: Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle q^{(t+1)}_{z_i} \leftarrow argmax_{q_{z_i}} \mathcal{F}(q_z(z), \theta^{(t)}) \forall i}
  • M step: Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \theta^{(t+1)} \leftarrow argmax_\theta \mathcal{F}(q^{(t+1)}_z(z), \theta)}

The E-step amounts to inferring the posterior distribution of the hidden variables Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle q^{(t+1)}_{z_i}} given the current parameter Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \theta^{(t)}} :

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle q^{(t+1)}_{z_i}(z_i) = p(z_i | x_i, \theta^{(t)})}

Indeed, the Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle q^{(t+1)}_{z_i}(z_i)} make the bound tight (the inequality becomes an equality):

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mathcal{F}(q^{(t+1)}_z(z), \theta^{(t)}) = \sum_{i=1}^N \int q^{(t+1)}_{z_i}(z_i) ln \left( \frac{p(x_i,z_i|\theta^{(t)})}{q^{(t+1)}_{z_i}(z_i)} \right) dz_i}

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mathcal{F}(q^{(t+1)}_z(z), \theta^{(t)}) = \sum_{i=1}^N \int p(z_i | x_i, \theta^{(t)}) ln \left( \frac{p(x_i,z_i|\theta^{(t)})}{p(z_i | x_i, \theta^{(t)})} \right) dz_i}

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mathcal{F}(q^{(t+1)}_z(z), \theta^{(t)}) = \sum_{i=1}^N \int p(z_i | x_i, \theta^{(t)}) ln \left( \frac{p(x_i|\theta^{(t)}) p(z_i|x_i,\theta^{(t)})}{p(z_i | x_i, \theta^{(t)})} \right) dz_i}

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mathcal{F}(q^{(t+1)}_z(z), \theta^{(t)}) = \sum_{i=1}^N \int p(z_i | x_i, \theta^{(t)}) ln \left( p(x_i|\theta^{(t)}) \right) dz_i}

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mathcal{F}(q^{(t+1)}_z(z), \theta^{(t)}) = \sum_{i=1}^N ln \left( p(x_i|\theta^{(t)}) \right)}

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mathcal{F}(q^{(t+1)}_z(z), \theta^{(t)}) = l_{obs}(\theta^{(t)})}

Then, at the M step, we use these statistics to maximize the new lower bound Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mathcal{F}} with respect to Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \theta} , and therefore find Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \theta^{(t+1)}} .


  • EM algorithm - variational: if the posterior distributions Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle p(z_i|x_i,\theta)} are intractable, we can use a variational approach to constrain them to be of a particular, tractable form. In the E step, maximizing Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mathcal{F}} with respect to Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle q_{z_i}} is equivalent to minimizing the Kullback-Leibler divergence between the variational distribution Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle q(z_i)} and the exact hidden variable posterior Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle p(z_i|x_i,\theta)} :

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle KL[q_{z_i}(z_i) || p(z_i|x_i,\theta)] = \int q_{z_i}(z_i) ln \left( \frac{q_{z_i}(z_i)}{p(z_i|x_i,\theta)} \right)}

As a result, the E step may not always lead to a tight bound.


  • Formulas of both steps: in both steps we need to use Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle Q} , whether to evaluate it or maximize it.

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle Q(\theta|X,\theta^{(t)}) = \mathbb{E}_{Z|X,\theta^{(t)}} \left[ ln(P(X,Z|\theta))|X,\theta^{(t)} \right]}

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle Q(\theta|X,\theta^{(t)}) = \mathbb{E}_{Z|X,\theta^{(t)}} \left[ l_{aug}(\theta)|X,\theta^{(t)} \right]}

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle Q(\theta|X,\theta^{(t)}) = \sum_{i=1}^N \left( \sum_{k=1}^K \mathbb{E}_{Z|X,\theta^{(t)}}[Z_{ik}|x_i,\theta_k^{(t)}] ln(\phi(x_i|\mu_k,\sigma_k)) + \sum_{k=1}^K \mathbb{E}_{Z|X,\theta^{(t)}}[Z_{ik}|x_i,\theta_k^{(t)}] ln(w_k) \right)}


  • Formulas of the E step: as indicated above, the E step consists in evaluating Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle Q} , i.e. simply evaluating the conditional expectation over the latent variables of the augmented-data log-likelihood given the observed data and the current estimates of the parameters.

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mathbb{E}_{Z|X,\theta^{(t)}}[Z_{ik}|x_i,\theta_k^{(t)}] = P(Z_{ik}=1|x_i,\theta_k^{(t)}) = \frac{w_k^{(t)} \phi(x_i|\mu_k^{(t)},\sigma_k^{(t)})}{\sum_{l=1}^K w_l^{(t)} \phi(x_i|\mu_l^{(t)},\sigma_l^{(t)})} = p(k|i)}


  • Formulas of the M step: in this step, we need to maximize Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle Q} (also written Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mathcal{F}} above), w.r.t. each Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \theta_k} . A few important rules are required to write down the analytical formulas of the MLEs, but only from a high-school level (see here).


  • M step - weights: let's start by finding the maximum-likelihood estimates of the weights Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle w_k} . But remember the constraint Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \sum_{k=1}^K w_k = 1} . To enforce it, we can use a Lagrange multiplier, Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \lambda} . This means that we now need to maximize the following equation where Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \Lambda} is a Lagrange function (only the part of Q being a function of the weights is kept):

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \Lambda(w_k,\lambda) = \sum_{i=1}^N \left( \sum_{k=1}^K p(k|i) ln(w_k) \right) + \lambda (1 - \sum_{k=1}^K w_k)}

As usual, to find the maximum, we derive and equal to zero:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\Lambda}{\partial w_k} = \sum_{i=1}^N \left( p(k|i) \frac{1}{\hat{w}_k^{(t+1)}} \right) - \lambda = 0}

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \hat{w}_k^{(t+1)} = \frac{1}{\lambda} \sum_{i=1}^N p(k|i)}

Now, to find the multiplier, we go back to the constraint:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \sum_{k=1}^K \hat{w}_k^{(t+1)} = 1 \rightarrow \lambda = \sum_{i=1}^N \sum_{k=1}^K p(k|i) = N}

Finally:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \hat{w}_k^{(t+1)} = \frac{1}{N} \sum_{i=1}^N p(k|i)}


  • M step - means:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial Q}{\partial \mu_k} = \sum_{i=1}^N p(k|i) \frac{\partial ln(\phi(x_i|\mu_k,\sigma_k))}{\partial \mu_k}}

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial Q}{\partial \mu_k} = \sum_{i=1}^N p(k|i) \frac{1}{\phi(x_i|\mu_k,\sigma_k)} \frac{\partial \phi(x_i|\mu_k,\sigma_k)}{\partial \mu_k}}

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial Q}{\partial \mu_k} = 0 = \sum_{i=1}^N p(k|i) (x_i - \hat{\mu}_k^{(t+1)})}

Finally:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \hat{\mu}_k^{(t+1)} = \frac{\sum_{i=1}^N p(k/i) x_i}{\sum_{i=1}^N p(k/i)}}


  • M step - variances: same kind of algebra

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial Q}{\partial \sigma_k} = \sum_{i=1}^N p(k/i) (\frac{-1}{\sigma_k} + \frac{(x_i - \mu_k)^2}{\sigma_k^3})}

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \hat{\sigma}_k^{(t+1)} = \sqrt{\frac{\sum_{i=1}^N p(k/i) (x_i - \hat{\mu}_k^{(t+1)})^2}{\sum_{i=1}^N p(k/i)}}}


  • M step - weights (2): we can write them in terms of unconstrained variables Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \gamma_k} (softmax function):

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle w_k = \frac{e^{\gamma_k}}{\sum_{k=1}^K e^{\gamma_k}}}

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial w_k}{\partial \gamma_j} = \begin{cases} w_k - w_k^2 & \mbox{if }j = k \\ -w_kw_j & \mbox{otherwise} \end{cases}}

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial l(\theta)}{\partial w_k} = \sum_{i=1}^N (p(k/i) - w_k)}

Finally:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \hat{w}_k = \frac{1}{N} \sum_{i=1}^N p(k/i)}


  • R code to simulate data: if you read up to there, nothing is better than implementing the EM algorithm yourself!
#' Generate univariate observations from a mixture of Normals
#'
#' @param K number of components
#' @param N number of observations
#' @param gap difference between all component means
GetUnivariateSimulatedData <- function(K=2, N=100, gap=6){
  mus <- seq(0, gap*(K-1), gap)
  sigmas <- runif(n=K, min=0.5, max=1.5)
  tmp <- floor(rnorm(n=K-1, mean=floor(N/K), sd=5))
  ns <- c(tmp, N - sum(tmp))
  clusters <- as.factor(matrix(unlist(lapply(1:K, function(k){rep(k, ns[k])})),
                               ncol=1))
  obs <- matrix(unlist(lapply(1:K, function(k){
    rnorm(n=ns[k], mean=mus[k], sd=sigmas[k])
  })))
  new.order <- sample(1:N, N)
  obs <- obs[new.order]
  rownames(obs) <- NULL
  clusters <- clusters[new.order]
  return(list(obs=obs, clusters=clusters, mus=mus, sigmas=sigmas,
              mix.weights=ns/N))
}

  • R code for the E step:
#' Return probas of latent variables given data and parameters from previous iteration
#'
#' @param data Nx1 vector of observations
#' @param params list which components are mus, sigmas and mix.weights
Estep <- function(data, params){
  GetMembershipProbas(data, params$mus, params$sigmas, params$mix.weights)
}

#' Return the membership probabilities P(zi=k/xi,theta)
#'
#' @param data Nx1 vector of observations
#' @param mus Kx1 vector of means
#' @param sigmas Kx1 vector of std deviations
#' @param mix.weights Kx1 vector of mixture weights w_k=P(zi=k/theta)
#' @return NxK matrix of membership probas
GetMembershipProbas <- function(data, mus, sigmas, mix.weights){
  N <- length(data)
  K <- length(mus)
  tmp <- matrix(unlist(lapply(1:N, function(i){
    x <- data[i]
    norm.const <- sum(unlist(Map(function(mu, sigma, mix.weight){
      mix.weight * GetUnivariateNormalDensity(x, mu, sigma)}, mus, sigmas, mix.weights)))
    unlist(Map(function(mu, sigma, mix.weight){
      mix.weight * GetUnivariateNormalDensity(x, mu, sigma) / norm.const
    }, mus[-K], sigmas[-K], mix.weights[-K]))
  })), ncol=K-1, byrow=TRUE)
  membership.probas <- cbind(tmp, apply(tmp, 1, function(x){1 - sum(x)}))
  names(membership.probas) <- NULL
  return(membership.probas)
}

#' Univariate Normal density
GetUnivariateNormalDensity <- function(x, mu, sigma){
  return( 1/(sigma * sqrt(2*pi)) * exp(-1/(2*sigma^2)*(x-mu)^2) )
}

  • R code for the M step:
#' Return ML estimates of parameters
#'
#' @param data Nx1 vector of observations
#' @param params list which components are mus, sigmas and mix.weights
#' @param membership.probas NxK matrix with entry i,k being P(zi=k/xi,theta)
Mstep <- function(data, params, membership.probas){
  params.new <- list()
  sum.membership.probas <- apply(membership.probas, 2, sum)
  params.new$mus <- GetMlEstimMeans(data, membership.probas,
                                    sum.membership.probas)
  params.new$sigmas <- GetMlEstimStdDevs(data, params.new$mus,
                                         membership.probas,
                                         sum.membership.probas)
  params.new$mix.weights <- GetMlEstimMixWeights(data, membership.probas,
                                                 sum.membership.probas)
  return(params.new)
}

#' Return ML estimates of the means (1 per cluster)
#'
#' @param data Nx1 vector of observations
#' @param membership.probas NxK matrix with entry i,k being P(zi=k/xi,theta)
#' @param sum.membership.probas Kx1 vector of sum per column of matrix above
#' @return Kx1 vector of means
GetMlEstimMeans <- function(data, membership.probas, sum.membership.probas){
  K <- ncol(membership.probas)
  sapply(1:K, function(k){
    sum(unlist(Map("*", membership.probas[,k], data))) /
      sum.membership.probas[k]
  })
}

#' Return ML estimates of the std deviations (1 per cluster)
#'
#' @param data Nx1 vector of observations
#' @param membership.probas NxK matrix with entry i,k being P(zi=k/xi,theta)
#' @param sum.membership.probas Kx1 vector of sum per column of matrix above
#' @return Kx1 vector of std deviations
GetMlEstimStdDevs <- function(data, means, membership.probas,
                              sum.membership.probas){
  K <- ncol(membership.probas)
  sapply(1:K, function(k){
    sqrt(sum(unlist(Map(function(p.ki, x.i){
      p.ki * (x.i - means[k])^2
    }, membership.probas[,k], data))) /
         sum.membership.probas[k])
  })
}

#' Return ML estimates of the mixture weights
#'
#' @param data Nx1 vector of observations
#' @param membership.probas NxK matrix with entry i,k being P(zi=k/xi,theta)
#' @param sum.membership.probas Kx1 vector of sum per column of matrix above
#' @return Kx1 vector of mixture weights
GetMlEstimMixWeights <- function(data, membership.probas,
                                 sum.membership.probas){
  K <- ncol(membership.probas)
  sapply(1:K, function(k){
    1/length(data) * sum.membership.probas[k]
  })
}

  • R code for the log-likelihood:
GetLogLikelihood <- function(data, mus, sigmas, mix.weights){
  loglik <- sum(sapply(data, function(x){
    log(sum(unlist(Map(function(mu, sigma, mix.weight){
      mix.weight * GetUnivariateNormalDensity(x, mu, sigma)
    }, mus, sigmas, mix.weights))))
  }))
  return(loglik)
}

  • R code for the EM loop:
EMalgo <- function(data, params, threshold.convergence=10^(-2), nb.iter=10,
                   verbose=1){
  logliks <- vector()
  i <- 1
  if(verbose > 0) cat(paste("iter ", i, "\n", sep=""))
  membership.probas <- Estep(data, params)
  params <- Mstep(data, params, membership.probas)
  loglik <- GetLogLikelihood(data, params$mus, params$sigmas,
                             params$mix.weights)
  logliks <- append(logliks, loglik)
  while(i < nb.iter){
    i <- i + 1
    if(verbose > 0) cat(paste("iter ", i, "\n", sep=""))
    membership.probas <- Estep(data, params)
    params <- Mstep(data, params, membership.probas)
    loglik <- GetLogLikelihood(data, params$mus, params$sigmas, params$mix.weights)
    if(loglik < logliks[length(logliks)]){
      msg <- paste("the log-likelihood is decreasing:", loglik, "<", logliks[length(logliks)])
      stop(msg, call.=FALSE)
    }
    logliks <- append(logliks, loglik)
    if(abs(logliks[i] - logliks[i-1]) <= threshold.convergence)
      break
  }
  return(list(params=params, membership.probas=membership.probas, logliks=logliks, nb.iters=i))
}

  • Example: and now, let's try it!
## simulate data
K <- 3
N <- 300
simul <- GetUnivariateSimulatedData(K, N)
data <- simul$obs

## run the EM algorithm
params0 <- list(mus=runif(n=K, min=min(data), max=max(data)),
                sigmas=rep(1, K),
                mix.weights=rep(1/K, K))
res <- EMalgo(data, params0, 10^(-3), 1000, 1)

## check its convergence
plot(res$logliks, xlab="iterations", ylab="log-likelihood",
     main="Convergence of the EM algorithm", type="b")

## plot the data along with the inferred densities
png("mixture_univar_em.png")
hist(data, breaks=30, freq=FALSE, col="grey", border="white", ylim=c(0,0.15),
     main="Histogram of data overlaid with densities inferred by EM")
rx <- seq(from=min(data), to=max(data), by=0.1)
ds <- lapply(1:K, function(k){dnorm(x=rx, mean=res$params$mus[k], sd=res$params$sigmas[k])})
f <- sapply(1:length(rx), function(i){
  res$params$mix.weights[1] * ds[[1]][i] + res$params$mix.weights[2] * ds[[2]][i] + res$params$mix.weights[3] * ds[[3]][i]
})
lines(rx, f, col="red", lwd=2)
dev.off()

It seems to work well, which was expected as the clusters are well separated from each other...

Mixture univariate em.png

The classification of each observation can be obtained via the following command:

## get the classification of the observations
memberships <- apply(res$membership.probas, 1, function(x){which(x > 0.5)})
table(memberships)

  • References:
    • chapter 1 from the PhD thesis of Matthew Stephens (Oxford, 2000) freely available online
    • chapter 2 from the PhD thesis of Matthew Beal (UCL, 2003) freely available online
    • lecture "Mixture Models, Latent Variables and the EM Algorithm" from Cosma Shalizi freely available online
    • book "Introducing Monte Carlo Methods with R" from Robert and and Casella (2009)