Difference between revisions of "User:Timothee Flutre/Notebook/Postdoc/2011/12/14"
(→Learn about mixture models and the EM algorithm: add link to graphical model representation) 
(→Learn about mixture models and the EM algorithm: clarify E step) 

Line 118:  Line 118:  
−  * '''  +  * '''Formulas of both steps''': in both steps we need to use <math>Q</math>, whether to evaluate it or maximize it. 
−  
−  
<math>Q(\thetaX,\theta^{(t)}) = \mathbb{E}_{ZX,\theta^{(t)}} \left[ ln(P(X,Z\theta))X,\theta^{(t)} \right]</math>  <math>Q(\thetaX,\theta^{(t)}) = \mathbb{E}_{ZX,\theta^{(t)}} \left[ ln(P(X,Z\theta))X,\theta^{(t)} \right]</math>  
Line 128:  Line 126:  
<math>Q(\thetaX,\theta^{(t)}) = \sum_{i=1}^N \left( \sum_{k=1}^K \mathbb{E}_{ZX,\theta^{(t)}}[Z_{ik}x_i,\theta_k^{(t)}] ln(\phi(x_i\mu_k,\sigma_k)) + \sum_{k=1}^K \mathbb{E}_{ZX,\theta^{(t)}}[Z_{ik}x_i,\theta_k^{(t)}] ln(w_k) \right)</math>  <math>Q(\thetaX,\theta^{(t)}) = \sum_{i=1}^N \left( \sum_{k=1}^K \mathbb{E}_{ZX,\theta^{(t)}}[Z_{ik}x_i,\theta_k^{(t)}] ln(\phi(x_i\mu_k,\sigma_k)) + \sum_{k=1}^K \mathbb{E}_{ZX,\theta^{(t)}}[Z_{ik}x_i,\theta_k^{(t)}] ln(w_k) \right)</math>  
−  
−  A few important rules are required to write down the analytical formulas of the MLEs, but only from a highschool level (see [http://en.wikipedia.org/wiki/Differentiation_%28mathematics%29#Rules_for_finding_the_derivative here]).  +  * '''Formulas of the E step''': as indicated above, the E step consists in evaluating <math>Q</math>, i.e. simply evaluating the conditional expectation over the latent variables of the augmenteddata loglikelihood given the observed data and the current estimates of the parameters. 
+  
+  <math>\mathbb{E}_{ZX,\theta^{(t)}}[Z_{ik}x_i,\theta_k^{(t)}] = P(Z_{ik}=1x_i,\theta_k^{(t)}) = \frac{w_k^{(t)} \phi(x_i\mu_k^{(t)},\sigma_k^{(t)})}{\sum_{l=1}^K w_l^{(t)} \phi(x_i\mu_l^{(t)},\sigma_l^{(t)})} = p(ki)</math>  
+  
+  
+  * '''Formulas of the M step''': in this step, we need to maximize <math>Q</math> (also written <math>\mathcal{F}</math> above), w.r.t. each <math>\theta_k</math>. A few important rules are required to write down the analytical formulas of the MLEs, but only from a highschool level (see [http://en.wikipedia.org/wiki/Differentiation_%28mathematics%29#Rules_for_finding_the_derivative here]).  
Revision as of 14:33, 8 May 2013
Project name  <html><img src="/images/9/94/Report.png" border="0" /></html> Main project page <html><img src="/images/c/c3/Resultset_previous.png" border="0" /></html>Previous entry<html> </html>Next entry<html><img src="/images/5/5c/Resultset_next.png" border="0" /></html> 
Learn about mixture models and the EM algorithm(Caution, this is my own quickanddirty tutorial, see the references at the end for presentations by professional statisticians.)
Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle f(x_i\theta) = \sum_{k=1}^{K} w_k \phi(x_i\mu_k,\sigma_k) = \sum_{k=1}^{K} w_k \frac{1}{\sqrt{2\pi} \sigma_k} \exp \left(\frac{1}{2}(\frac{x_i  \mu_k}{\sigma_k})^2 \right)} The constraints are: Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \forall k, w_k > 0} and Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \sum_{k=1}^K w_k = 1}
Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle L(\theta) = P(X\theta) = \prod_{i=1}^N f(x_i\theta)} As usual, it's easier to deal with the loglikelihood: Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle l(\theta) = \sum_{i=1}^N ln \left( f(x_i\theta) \right) = \sum_{i=1}^N ln \left( \sum_{k=1}^K w_k \phi(x_i; \theta_k) \right)} Let's take the derivative with respect to one parameter, eg. Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \theta_l} : Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial l}{\partial \theta_l} = \sum_{i=1}^N \frac{1}{\sum_{k=1}^K w_k \phi(x_i; \theta_k)} w_l \frac{\partial \phi(x_i; \theta_l)}{\partial \theta_l}} Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial l}{\partial \theta_l} = \sum_{i=1}^N \frac{w_l \phi(x_i; \theta_l)}{\sum_{k=1}^K w_k \phi(x_i; \theta_k)} \frac{1}{\phi(x_i; \theta_l)} \frac{\partial \phi(x_i; \theta_l)}{\partial \theta_l}} Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial l}{\partial \theta_l} = \sum_{i=1}^N \frac{w_l \phi(x_i; \theta_l)}{\sum_{k=1}^K w_k \phi(x_i; \theta_k)} \frac{\partial ln ( \phi(x_i; \theta_l) )}{\partial \theta_l}} This shows that maximizing the likelihood of a mixture model is like doing a weighted likelihood maximization. However, these weights depend on the parameters we want to estimate! That's why we now switch to the missingdata formulation of the mixture model.
Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle p(ki) = P(Z_{ik}=1x_i,\theta) = \frac{w_k \phi(x_i\mu_k,\sigma_k)}{\sum_{l=1}^K w_l \phi(x_i\mu_l,\sigma_l)}} The observeddata likelihood (also called sometimes "incomplete" or "marginal", even though these appellations are misnomers) is still written the same way: Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle L_{obs}(\theta) = P(X\theta) = \prod_{i=1}^N f(x_i\theta)} But now we can also write the augmenteddata likelihood, assuming all observations are independent conditionally on their membership: Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle L_{aug}(\theta) = P(X,Z\theta) = \prod_{i=1}^N P(x_iZ_i,\theta) P(Z_i\theta) = \prod_{i=1}^N \left( \prod_{k=1}^K \phi(x_i\mu_k,\sigma_k)^{Z_{ik}} w_k^{Z_{ik}} \right)} . And here is the augmenteddata loglikelihood (useful in the M step of the EM algorithm, see below): Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle l_{aug}(\theta) = \sum_{i=1}^N \left( \sum_{k=1}^K Z_{ik} ln(\phi(x_i\mu_k,\sigma_k)) + \sum_{k=1}^K Z_{ik} ln(w_k) \right)} In terms of graphical model, the Gaussian mixture model described here can be represented like this.
Here is the observeddata loglikelihood: Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle l_{obs}(\theta) = \sum_{i=1}^N ln \left( f(x_i\theta) \right)} First we introduce the hidden variables by integrating them out: Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle l_{obs}(\theta) = \sum_{i=1}^N ln \left( \int p(x_i,z_i\theta) dz_i \right)} Then, we use any probability distribution Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle q} on these hidden variables (in fact, we use a distinct distribution Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle q_{z_i}} for each observation): Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle l_{obs}(\theta) = \sum_{i=1}^N ln \left( \int q_{z_i}(z_i) \frac{p(x_i,z_i\theta)}{q_{z_i}(z_i)} dz_i \right)} And here is the great trick, as explained by Beal: "any probability distribution over the hidden variables gives rise to a lower bound on Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle l_{obs}} ". This is due to to the Jensen inequality (the logarithm is concave): Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle l_{obs}(\theta) \ge \sum_{i=1}^N \int q_{z_i}(z_i) ln \left( \frac{p(x_i,z_i\theta)}{q_{z_i}(z_i)} \right) dz_i = \mathcal{F}(q_{z_1}(z_1), ..., q_{z_N}(z_N), \theta)} At each iteration, the E step maximizes the lower bound (Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mathcal{F}} ) with respect to the Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle q_{z_i}(z_i)} :
The Estep amounts to inferring the posterior distribution of the hidden variables Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle q^{(t+1)}_{z_i}} given the current parameter Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \theta^{(t)}} : Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle q^{(t+1)}_{z_i}(z_i) = p(z_i  x_i, \theta^{(t)})} Indeed, the Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle q^{(t+1)}_{z_i}(z_i)} make the bound tight (the inequality becomes an equality): Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mathcal{F}(q^{(t+1)}_z(z), \theta^{(t)}) = \sum_{i=1}^N \int q^{(t+1)}_{z_i}(z_i) ln \left( \frac{p(x_i,z_i\theta^{(t)})}{q^{(t+1)}_{z_i}(z_i)} \right) dz_i} Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mathcal{F}(q^{(t+1)}_z(z), \theta^{(t)}) = \sum_{i=1}^N \int p(z_i  x_i, \theta^{(t)}) ln \left( \frac{p(x_i,z_i\theta^{(t)})}{p(z_i  x_i, \theta^{(t)})} \right) dz_i} Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mathcal{F}(q^{(t+1)}_z(z), \theta^{(t)}) = \sum_{i=1}^N \int p(z_i  x_i, \theta^{(t)}) ln \left( \frac{p(x_i\theta^{(t)}) p(z_ix_i,\theta^{(t)})}{p(z_i  x_i, \theta^{(t)})} \right) dz_i} Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mathcal{F}(q^{(t+1)}_z(z), \theta^{(t)}) = \sum_{i=1}^N \int p(z_i  x_i, \theta^{(t)}) ln \left( p(x_i\theta^{(t)}) \right) dz_i} Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mathcal{F}(q^{(t+1)}_z(z), \theta^{(t)}) = \sum_{i=1}^N ln \left( p(x_i\theta^{(t)}) \right)} Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mathcal{F}(q^{(t+1)}_z(z), \theta^{(t)}) = l_{obs}(\theta^{(t)})} Then, at the M step, we use these statistics to maximize the new lower bound Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mathcal{F}} with respect to Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \theta} , and therefore find Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \theta^{(t+1)}} .
Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle KL[q_{z_i}(z_i)  p(z_ix_i,\theta)] = \int q_{z_i}(z_i) ln \left( \frac{q_{z_i}(z_i)}{p(z_ix_i,\theta)} \right)} As a result, the E step may not always lead to a tight bound.
Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle Q(\thetaX,\theta^{(t)}) = \mathbb{E}_{ZX,\theta^{(t)}} \left[ ln(P(X,Z\theta))X,\theta^{(t)} \right]} Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle Q(\thetaX,\theta^{(t)}) = \mathbb{E}_{ZX,\theta^{(t)}} \left[ l_{aug}(\theta)X,\theta^{(t)} \right]} Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle Q(\thetaX,\theta^{(t)}) = \sum_{i=1}^N \left( \sum_{k=1}^K \mathbb{E}_{ZX,\theta^{(t)}}[Z_{ik}x_i,\theta_k^{(t)}] ln(\phi(x_i\mu_k,\sigma_k)) + \sum_{k=1}^K \mathbb{E}_{ZX,\theta^{(t)}}[Z_{ik}x_i,\theta_k^{(t)}] ln(w_k) \right)}
Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mathbb{E}_{ZX,\theta^{(t)}}[Z_{ik}x_i,\theta_k^{(t)}] = P(Z_{ik}=1x_i,\theta_k^{(t)}) = \frac{w_k^{(t)} \phi(x_i\mu_k^{(t)},\sigma_k^{(t)})}{\sum_{l=1}^K w_l^{(t)} \phi(x_i\mu_l^{(t)},\sigma_l^{(t)})} = p(ki)}
Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \Lambda(w_k,\lambda) = \sum_{i=1}^N \left( \sum_{k=1}^K p(ki) ln(w_k) \right) + \lambda (1  \sum_{k=1}^K w_k)} As usual, to find the maximum, we derive and equal to zero: Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\Lambda}{\partial w_k} = \sum_{i=1}^N \left( p(ki) \frac{1}{\hat{w}_k^{(t+1)}} \right)  \lambda = 0} Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \hat{w}_k^{(t+1)} = \frac{1}{\lambda} \sum_{i=1}^N p(ki)} Now, to find the multiplier, we go back to the constraint: Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \sum_{k=1}^K \hat{w}_k^{(t+1)} = 1 \rightarrow \lambda = \sum_{i=1}^N \sum_{k=1}^K p(ki) = N} Finally: Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \hat{w}_k^{(t+1)} = \frac{1}{N} \sum_{i=1}^N p(ki)}
Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial Q}{\partial \mu_k} = \sum_{i=1}^N p(ki) \frac{\partial ln(\phi(x_i\mu_k,\sigma_k))}{\partial \mu_k}} Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial Q}{\partial \mu_k} = \sum_{i=1}^N p(ki) \frac{1}{\phi(x_i\mu_k,\sigma_k)} \frac{\partial \phi(x_i\mu_k,\sigma_k)}{\partial \mu_k}} Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial Q}{\partial \mu_k} = 0 = \sum_{i=1}^N p(ki) (x_i  \hat{\mu}_k^{(t+1)})} Finally: Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \hat{\mu}_k^{(t+1)} = \frac{\sum_{i=1}^N p(k/i) x_i}{\sum_{i=1}^N p(k/i)}}
Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial Q}{\partial \sigma_k} = \sum_{i=1}^N p(k/i) (\frac{1}{\sigma_k} + \frac{(x_i  \mu_k)^2}{\sigma_k^3})} Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \hat{\sigma}_k^{(t+1)} = \sqrt{\frac{\sum_{i=1}^N p(k/i) (x_i  \hat{\mu}_k^{(t+1)})^2}{\sum_{i=1}^N p(k/i)}}}
Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle w_k = \frac{e^{\gamma_k}}{\sum_{k=1}^K e^{\gamma_k}}} Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial w_k}{\partial \gamma_j} = \begin{cases} w_k  w_k^2 & \mbox{if }j = k \\ w_kw_j & \mbox{otherwise} \end{cases}} Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial l(\theta)}{\partial w_k} = \sum_{i=1}^N (p(k/i)  w_k)} Finally: Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \hat{w}_k = \frac{1}{N} \sum_{i=1}^N p(k/i)}
#' Generate univariate observations from a mixture of Normals #' #' @param K number of components #' @param N number of observations #' @param gap difference between all component means GetUnivariateSimulatedData < function(K=2, N=100, gap=6){ mus < seq(0, gap*(K1), gap) sigmas < runif(n=K, min=0.5, max=1.5) tmp < floor(rnorm(n=K1, mean=floor(N/K), sd=5)) ns < c(tmp, N  sum(tmp)) clusters < as.factor(matrix(unlist(lapply(1:K, function(k){rep(k, ns[k])})), ncol=1)) obs < matrix(unlist(lapply(1:K, function(k){ rnorm(n=ns[k], mean=mus[k], sd=sigmas[k]) }))) new.order < sample(1:N, N) obs < obs[new.order] rownames(obs) < NULL clusters < clusters[new.order] return(list(obs=obs, clusters=clusters, mus=mus, sigmas=sigmas, mix.weights=ns/N)) }
#' Return probas of latent variables given data and parameters from previous iteration #' #' @param data Nx1 vector of observations #' @param params list which components are mus, sigmas and mix.weights Estep < function(data, params){ GetMembershipProbas(data, params$mus, params$sigmas, params$mix.weights) } #' Return the membership probabilities P(zi=k/xi,theta) #' #' @param data Nx1 vector of observations #' @param mus Kx1 vector of means #' @param sigmas Kx1 vector of std deviations #' @param mix.weights Kx1 vector of mixture weights w_k=P(zi=k/theta) #' @return NxK matrix of membership probas GetMembershipProbas < function(data, mus, sigmas, mix.weights){ N < length(data) K < length(mus) tmp < matrix(unlist(lapply(1:N, function(i){ x < data[i] norm.const < sum(unlist(Map(function(mu, sigma, mix.weight){ mix.weight * GetUnivariateNormalDensity(x, mu, sigma)}, mus, sigmas, mix.weights))) unlist(Map(function(mu, sigma, mix.weight){ mix.weight * GetUnivariateNormalDensity(x, mu, sigma) / norm.const }, mus[K], sigmas[K], mix.weights[K])) })), ncol=K1, byrow=TRUE) membership.probas < cbind(tmp, apply(tmp, 1, function(x){1  sum(x)})) names(membership.probas) < NULL return(membership.probas) } #' Univariate Normal density GetUnivariateNormalDensity < function(x, mu, sigma){ return( 1/(sigma * sqrt(2*pi)) * exp(1/(2*sigma^2)*(xmu)^2) ) }
#' Return ML estimates of parameters #' #' @param data Nx1 vector of observations #' @param params list which components are mus, sigmas and mix.weights #' @param membership.probas NxK matrix with entry i,k being P(zi=k/xi,theta) Mstep < function(data, params, membership.probas){ params.new < list() sum.membership.probas < apply(membership.probas, 2, sum) params.new$mus < GetMlEstimMeans(data, membership.probas, sum.membership.probas) params.new$sigmas < GetMlEstimStdDevs(data, params.new$mus, membership.probas, sum.membership.probas) params.new$mix.weights < GetMlEstimMixWeights(data, membership.probas, sum.membership.probas) return(params.new) } #' Return ML estimates of the means (1 per cluster) #' #' @param data Nx1 vector of observations #' @param membership.probas NxK matrix with entry i,k being P(zi=k/xi,theta) #' @param sum.membership.probas Kx1 vector of sum per column of matrix above #' @return Kx1 vector of means GetMlEstimMeans < function(data, membership.probas, sum.membership.probas){ K < ncol(membership.probas) sapply(1:K, function(k){ sum(unlist(Map("*", membership.probas[,k], data))) / sum.membership.probas[k] }) } #' Return ML estimates of the std deviations (1 per cluster) #' #' @param data Nx1 vector of observations #' @param membership.probas NxK matrix with entry i,k being P(zi=k/xi,theta) #' @param sum.membership.probas Kx1 vector of sum per column of matrix above #' @return Kx1 vector of std deviations GetMlEstimStdDevs < function(data, means, membership.probas, sum.membership.probas){ K < ncol(membership.probas) sapply(1:K, function(k){ sqrt(sum(unlist(Map(function(p.ki, x.i){ p.ki * (x.i  means[k])^2 }, membership.probas[,k], data))) / sum.membership.probas[k]) }) } #' Return ML estimates of the mixture weights #' #' @param data Nx1 vector of observations #' @param membership.probas NxK matrix with entry i,k being P(zi=k/xi,theta) #' @param sum.membership.probas Kx1 vector of sum per column of matrix above #' @return Kx1 vector of mixture weights GetMlEstimMixWeights < function(data, membership.probas, sum.membership.probas){ K < ncol(membership.probas) sapply(1:K, function(k){ 1/length(data) * sum.membership.probas[k] }) }
GetLogLikelihood < function(data, mus, sigmas, mix.weights){ loglik < sum(sapply(data, function(x){ log(sum(unlist(Map(function(mu, sigma, mix.weight){ mix.weight * GetUnivariateNormalDensity(x, mu, sigma) }, mus, sigmas, mix.weights)))) })) return(loglik) }
EMalgo < function(data, params, threshold.convergence=10^(2), nb.iter=10, verbose=1){ logliks < vector() i < 1 if(verbose > 0) cat(paste("iter ", i, "\n", sep="")) membership.probas < Estep(data, params) params < Mstep(data, params, membership.probas) loglik < GetLogLikelihood(data, params$mus, params$sigmas, params$mix.weights) logliks < append(logliks, loglik) while(i < nb.iter){ i < i + 1 if(verbose > 0) cat(paste("iter ", i, "\n", sep="")) membership.probas < Estep(data, params) params < Mstep(data, params, membership.probas) loglik < GetLogLikelihood(data, params$mus, params$sigmas, params$mix.weights) if(loglik < logliks[length(logliks)]){ msg < paste("the loglikelihood is decreasing:", loglik, "<", logliks[length(logliks)]) stop(msg, call.=FALSE) } logliks < append(logliks, loglik) if(abs(logliks[i]  logliks[i1]) <= threshold.convergence) break } return(list(params=params, membership.probas=membership.probas, logliks=logliks, nb.iters=i)) }
## simulate data K < 3 N < 300 simul < GetUnivariateSimulatedData(K, N) data < simul$obs ## run the EM algorithm params0 < list(mus=runif(n=K, min=min(data), max=max(data)), sigmas=rep(1, K), mix.weights=rep(1/K, K)) res < EMalgo(data, params0, 10^(3), 1000, 1) ## check its convergence plot(res$logliks, xlab="iterations", ylab="loglikelihood", main="Convergence of the EM algorithm", type="b") ## plot the data along with the inferred densities png("mixture_univar_em.png") hist(data, breaks=30, freq=FALSE, col="grey", border="white", ylim=c(0,0.15), main="Histogram of data overlaid with densities inferred by EM") rx < seq(from=min(data), to=max(data), by=0.1) ds < lapply(1:K, function(k){dnorm(x=rx, mean=res$params$mus[k], sd=res$params$sigmas[k])}) f < sapply(1:length(rx), function(i){ res$params$mix.weights[1] * ds[[1]][i] + res$params$mix.weights[2] * ds[[2]][i] + res$params$mix.weights[3] * ds[[3]][i] }) lines(rx, f, col="red", lwd=2) dev.off() It seems to work well, which was expected as the clusters are well separated from each other... The classification of each observation can be obtained via the following command: ## get the classification of the observations memberships < apply(res$membership.probas, 1, function(x){which(x > 0.5)}) table(memberships)
