Difference between revisions of "User:Timothee Flutre/Notebook/Postdoc/2011/12/14"

From OpenWetWare
Jump to: navigation, search
(Learn about mixture models and the EM algorithm: add option gap + add code to get classif)
(Learn about mixture models and the EM algorithm: add constraints mix weights)
Line 19: Line 19:
  
 
<math>f(x_i/\theta) = \sum_{k=1}^{K} w_k g(x_i/\mu_k,\sigma_k) = \sum_{k=1}^{K} w_k \frac{1}{\sqrt{2\pi} \sigma_k} \exp \left(-\frac{1}{2}(\frac{x_i - \mu_k}{\sigma_k})^2 \right)</math>
 
<math>f(x_i/\theta) = \sum_{k=1}^{K} w_k g(x_i/\mu_k,\sigma_k) = \sum_{k=1}^{K} w_k \frac{1}{\sqrt{2\pi} \sigma_k} \exp \left(-\frac{1}{2}(\frac{x_i - \mu_k}{\sigma_k})^2 \right)</math>
 +
 +
The constraints are:
 +
<math>\forall k, w_k > 0</math> and <math>\sum_{k=1}^K w_k = 1</math>
  
 
* '''Missing data''': it is worth noting that a big piece of information is lacking here. We aim at finding the parameters defining the mixture, but we don't know from which cluster each observation is coming! That's why we need to introduce the following N [http://en.wikipedia.org/wiki/Latent_variable latent variables] <math>Z_1,...,Z_i,...,Z_N</math>, one for each observation, such that <math>Z_i=k</math> means that observation <math>x_i</math> belongs to cluster <math>k</math> ([http://en.wikipedia.org/wiki/Dummy_variable_%28statistics%29 indicators]). This is called the "missing data formulation" of the mixture model. Thanks to this, we can reinterpret the mixture weights: <math>w_k = P(Z_i=k/\theta)</math>. Moreover, we can now define the membership probabilities, one for each observation:
 
* '''Missing data''': it is worth noting that a big piece of information is lacking here. We aim at finding the parameters defining the mixture, but we don't know from which cluster each observation is coming! That's why we need to introduce the following N [http://en.wikipedia.org/wiki/Latent_variable latent variables] <math>Z_1,...,Z_i,...,Z_N</math>, one for each observation, such that <math>Z_i=k</math> means that observation <math>x_i</math> belongs to cluster <math>k</math> ([http://en.wikipedia.org/wiki/Dummy_variable_%28statistics%29 indicators]). This is called the "missing data formulation" of the mixture model. Thanks to this, we can reinterpret the mixture weights: <math>w_k = P(Z_i=k/\theta)</math>. Moreover, we can now define the membership probabilities, one for each observation:

Revision as of 07:29, 4 January 2012

Owwnotebook icon.png Project name <html><img src="/images/9/94/Report.png" border="0" /></html> Main project page
<html><img src="/images/c/c3/Resultset_previous.png" border="0" /></html>Previous entry<html>&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;</html>Next entry<html><img src="/images/5/5c/Resultset_next.png" border="0" /></html>

Learn about mixture models and the EM algorithm

(Caution, this is my own quick-and-dirty tutorial, see the references at the end for presentations by professional statisticians.)

  • Motivation: a large part of any scientific activity is about measuring things, in other words collecting data, and it is not infrequent to collect heterogeneous data. For instance, we measure the height of individuals without recording their gender, we measure the levels of expression of a gene in several individuals without recording which ones are healthy and which ones are sick, etc. It seems therefore natural to say that the samples come from a mixture of clusters. The aim is then to recover from the data, ie. to infer, (i) the values of the parameters of the probability distribution of each cluster, and (ii) from which cluster each sample comes from.
  • Data: we have N observations, noted Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle X = (x_1, x_2, ..., x_N)} . For the moment, we suppose that each observation Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle x_i} is univariate, ie. each corresponds to only one number.
  • Hypothesis: let's assume that the data are heterogeneous and that they can be partitioned into Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle K} clusters (see examples above). This means that we expect a subset of the observations to come from cluster Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle k=1} , another subset to come from cluster Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle k=2} , and so on.
  • Model: technically, we say that the observations were generated according to a density function Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle f} . More precisely, this density is itself a mixture of densities, one per cluster. In our case, we will assume that each cluster Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle k} corresponds to a Normal distribution, which density is here noted Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle g} , with mean Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mu_k} and standard deviation Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \sigma_k} . Moreover, as we don't know for sure from which cluster a given observation comes from, we define the mixture weight Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle w_k} to be the probability that any given observation comes from cluster Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle k} . As a result, we have the following list of parameters: Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \theta=(w_1,...,w_K,\mu_1,...\mu_K,\sigma_1,...,\sigma_K)} . Finally, for a given observation Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle x_i} , we can write the model:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle f(x_i/\theta) = \sum_{k=1}^{K} w_k g(x_i/\mu_k,\sigma_k) = \sum_{k=1}^{K} w_k \frac{1}{\sqrt{2\pi} \sigma_k} \exp \left(-\frac{1}{2}(\frac{x_i - \mu_k}{\sigma_k})^2 \right)}

The constraints are: Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \forall k, w_k > 0} and Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \sum_{k=1}^K w_k = 1}

  • Missing data: it is worth noting that a big piece of information is lacking here. We aim at finding the parameters defining the mixture, but we don't know from which cluster each observation is coming! That's why we need to introduce the following N latent variables Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle Z_1,...,Z_i,...,Z_N} , one for each observation, such that Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle Z_i=k} means that observation Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle x_i} belongs to cluster Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle k} (indicators). This is called the "missing data formulation" of the mixture model. Thanks to this, we can reinterpret the mixture weights: Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle w_k = P(Z_i=k/\theta)} . Moreover, we can now define the membership probabilities, one for each observation:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle p(k/i) = P(Z_i=k/x_i,\theta) = \frac{w_k g(x_i/\mu_k,\sigma_k)}{\sum_{l=1}^K w_l g(x_i/\mu_l,\sigma_l)}}

We can now write the complete likelihood, ie. the likelihood of the augmented model (even if we don't need it in the following), where Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle I_k = \{i / Z_i = k\}} : Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle L_{comp}(\theta) = P(X,Z/\theta) = P(X/Z,\theta) P(Z/\theta) = \left( \prod_{k=1}^K \prod_{i \in I_k} g(x_i/\mu_k,\sigma_k) \right) \prod_{i=1}^N P(Z_i/\theta)} .

And, more useful, the incomplete (or marginal) likelihood, assuming all observations are independent:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle L_{incomp}(\theta) = P(X/\theta) = \prod_{i=1}^N f(x_i/\theta)}

  • ML estimation: we want to find the values of the parameters that maximize the likelihood. This reduces to (i) differentiating the log-likelihood with respect to each parameter, and then (ii) finding the value at which each partial derivative is zero. Instead of maximizing the likelihood, we maximize its logarithm, noted Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle l(\theta)} . It gives the same solution because the log is monotonically increasing, but it's easier to derive the log-likelihood than the likelihood. Here is the whole formula for the (incomplete) log-likelihood:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle l(\theta) = \sum_{i=1}^N log(f(x_i/\theta)) = \sum_{i=1}^N log \left( \sum_{k=1}^{K} w_k \frac{1}{\sqrt{2\pi} \sigma_k} \exp^{-\frac{1}{2}(\frac{x_i - \mu_k}{\sigma_k})^2} \right)}

  • MLE analytical formulae: a few important rules are required, but only from a high-school level in maths (see here). Let's start by finding the maximum-likelihood estimates of the mean of each cluster:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial l(\theta)}{\partial \mu_k} = \sum_{i=1}^N \frac{1}{f(x_i/\theta)} \frac{\partial f(x_i/\theta)}{\partial \mu_k}}

As we derive with respect to Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mu_k} , all the others means Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mu_l} with Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle l \ne k} are constant, and thus disappear:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial f(x_i/\theta)}{\partial \mu_k} = w_k \frac{\partial g(x_i/\mu_k,\sigma_k)}{\partial \mu_k}}

And finally:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial g(x_i/\mu_k,\sigma_k)}{\partial \mu_k} = \frac{\mu_k - x_i}{\sigma_k^2} g(x_i/\mu_k,\sigma_k)}

Once we put all together, we end up with:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial l(\theta)}{\partial \mu_k} = \sum_{i=1}^N \frac{1}{\sigma^2} \frac{w_k g(x_i/\mu_k,\sigma_k)}{\sum_{l=1}^K w_l g(x_i/\mu_l,\sigma_l)} (\mu_k - x_i) = \sum_{i=1}^N \frac{1}{\sigma^2} p(k/i) (\mu_k - x_i)}

By convention, we note Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \hat{\mu_k}} the maximum-likelihood estimate of Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \mu_k} :

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial l(\theta)}{\partial \mu_k}_{\mu_k=\hat{\mu_k}} = 0}

Therefore, we finally obtain:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \hat{\mu_k} = \frac{\sum_{i=1}^N p(k/i) x_i}{\sum_{i=1}^N p(k/i)}}

By doing the same kind of algebra, we derive the log-likelihood w.r.t. Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \sigma_k} :

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial l(\theta)}{\partial \sigma_k} = \sum_{i=1}^N p(k/i) (\frac{-1}{\sigma_k} + \frac{(x_i - \mu_k)^2}{\sigma_k^3})}

And then we obtain the ML estimates for the standard deviation of each cluster:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \hat{\sigma_k} = \sqrt{\frac{\sum_{i=1}^N p(k/i) (x_i - \mu_k)^2}{\sum_{i=1}^N p(k/i)}}}

The partial derivative of Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle l(\theta)} w.r.t. Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle w_k} is tricky. ... <TO DO> ...

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \frac{\partial l(\theta)}{\partial w_k} = \sum_{i=1}^N (p(k/i) - w_k)}

Finally, here are the ML estimates for the mixture weights:

Failed to parse (MathML with SVG or PNG fallback (recommended for modern browsers and accessibility tools): Invalid response ("Math extension cannot connect to Restbase.") from server "https://api.formulasearchengine.com/v1/":): {\displaystyle \hat{w}_k = \frac{1}{N} \sum_{i=1}^N p(k/i)}

  • EM algorithm: ... <TO DO> ...
  • R code to simulate data:
#' Generate univariate observations from a mixture of Normals
#'
#' @param K number of components
#' @param N number of observations
#' @param gap difference between all component means
GetUnivariateSimulatedData <- function(K=2, N=100, gap=6){
  mus <- seq(0, gap*(K-1), gap)
  sigmas <- runif(n=K, min=0.5, max=1.5)
  tmp <- floor(rnorm(n=K-1, mean=floor(N/K), sd=5))
  ns <- c(tmp, N - sum(tmp))
  clusters <- as.factor(matrix(unlist(lapply(1:K, function(k){rep(k, ns[k])})),
                               ncol=1))
  obs <- matrix(unlist(lapply(1:K, function(k){
    rnorm(n=ns[k], mean=mus[k], sd=sigmas[k])
  })))
  new.order <- sample(1:N, N)
  obs <- obs[new.order]
  rownames(obs) <- NULL
  clusters <- clusters[new.order]
  return(list(obs=obs, clusters=clusters, mus=mus, sigmas=sigmas,
              mix.weights=ns/N))
}
  • R code for the E step:
#' Return probas of latent variables given data and parameters from previous iteration
#'
#' @param data Nx1 vector of observations
#' @param params list which components are mus, sigmas and mix.weights
Estep <- function(data, params){
  GetMembershipProbas(data, params$mus, params$sigmas, params$mix.weights)
}
#' Return the membership probabilities P(zi=k/xi,theta)
#'
#' @param data Nx1 vector of observations
#' @param mus Kx1 vector of means
#' @param sigmas Kx1 vector of std deviations
#' @param mix.weights Kx1 vector of mixture weights w_k=P(zi=k/theta)
#' @return NxK matrix of membership probas
GetMembershipProbas <- function(data, mus, sigmas, mix.weights){
  N <- length(data)
  K <- length(mus)
  tmp <- matrix(unlist(lapply(1:N, function(i){
    x <- data[i]
    norm.const <- sum(unlist(Map(function(mu, sigma, mix.weight){
      mix.weight * GetUnivariateNormalDensity(x, mu, sigma)}, mus, sigmas, mix.weights)))
    unlist(Map(function(mu, sigma, mix.weight){
      mix.weight * GetUnivariateNormalDensity(x, mu, sigma) / norm.const
    }, mus[-K], sigmas[-K], mix.weights[-K]))
  })), ncol=K-1, byrow=TRUE)
  membership.probas <- cbind(tmp, apply(tmp, 1, function(x){1 - sum(x)}))
  names(membership.probas) <- NULL
  return(membership.probas)
}
#' Univariate Normal density
GetUnivariateNormalDensity <- function(x, mu, sigma){
  return( 1/(sigma * sqrt(2*pi)) * exp(-1/(2*sigma^2)*(x-mu)^2) )
}
  • R code for the M step:
#' Return ML estimates of parameters
#'
#' @param data Nx1 vector of observations
#' @param params list which components are mus, sigmas and mix.weights
#' @param membership.probas NxK matrix with entry i,k being P(zi=k/xi,theta)
Mstep <- function(data, params, membership.probas){
  params.new <- list()
  sum.membership.probas <- apply(membership.probas, 2, sum)
  params.new$mus <- GetMlEstimMeans(data, membership.probas,
                                    sum.membership.probas)
  params.new$sigmas <- GetMlEstimStdDevs(data, params.new$mus,
                                         membership.probas,
                                         sum.membership.probas)
  params.new$mix.weights <- GetMlEstimMixWeights(data, membership.probas,
                                                 sum.membership.probas)
  return(params.new)
}
#' Return ML estimates of the means (1 per cluster)
#'
#' @param data Nx1 vector of observations
#' @param membership.probas NxK matrix with entry i,k being P(zi=k/xi,theta)
#' @param sum.membership.probas Kx1 vector of sum per column of matrix above
#' @return Kx1 vector of means
GetMlEstimMeans <- function(data, membership.probas, sum.membership.probas){
  K <- ncol(membership.probas)
  sapply(1:K, function(k){
    sum(unlist(Map("*", membership.probas[,k], data))) /
      sum.membership.probas[k]
  })
}
#' Return ML estimates of the std deviations (1 per cluster)
#'
#' @param data Nx1 vector of observations
#' @param membership.probas NxK matrix with entry i,k being P(zi=k/xi,theta)
#' @param sum.membership.probas Kx1 vector of sum per column of matrix above
#' @return Kx1 vector of std deviations
GetMlEstimStdDevs <- function(data, means, membership.probas,
                              sum.membership.probas){
  K <- ncol(membership.probas)
  sapply(1:K, function(k){
    sqrt(sum(unlist(Map(function(p_ki, x_i){
      p_ki * (x_i - means[k])^2
    }, membership.probas[,k], data))) /
         sum.membership.probas[k])
  })
}
#' Return ML estimates of the mixture weights
#'
#' @param data Nx1 vector of observations
#' @param membership.probas NxK matrix with entry i,k being P(zi=k/xi,theta)
#' @param sum.membership.probas Kx1 vector of sum per column of matrix above
#' @return Kx1 vector of mixture weights
GetMlEstimMixWeights <- function(data, membership.probas,
                                 sum.membership.probas){
  K <- ncol(membership.probas)
  sapply(1:K, function(k){
    1/length(data) * sum.membership.probas[k]
  })
}
  • R code for the EM loop:

... <TO DO> ...

  • Example: and now, let's try it!
## simulate data
K <- 3
N <- 300
simul <- GetUnivariateSimulatedData(K, N)
data <- simul$obs
## run the EM algorithm
params0 <- list(mus=runif(n=K, min=min(data), max=max(data)),
                sigmas=rep(1, K),
                mix.weights=rep(1/K, K))
res <- EMalgo(data, params0, 10^(-3), 1000, 1)
## check its convergence
plot(res$logliks, xlab="iterations", ylab="log-likelihood",
     main="Convergence of the EM algorithm", type="b")
## plot the data along with the inferred densities
png("mixture_univar_em.png")
hist(data, breaks=30, freq=FALSE, col="grey", border="white", ylim=c(0,0.15),
     main="Histogram of data overlaid with densities inferred by EM")
rx <- seq(from=min(data), to=max(data), by=0.1)
ds <- lapply(1:K, function(k){dnorm(x=rx, mean=res$params$mus[k], sd=res$params$sigmas[k])})
f <- sapply(1:length(rx), function(i){
  res$params$mix.weights[1] * ds[[1]][i] + res$params$mix.weights[2] * ds[[2]][i] + res$params$mix.weights[3] * ds[[3]][i]
})
lines(rx, f, col="red", lwd=2)
dev.off()

It seems to work well, which was expected as the clusters are well separated from each other...

Mixture univariate em.png

The classification of each observation can be obtained via the following command:

## get the classification of the observations
memberships <- apply(res$membership.probas, 1, function(x){which(x > 0.5)})
table(memberships)
  • References:
    • introduction (ch.1) of the PhD thesis from Matthew Stephens (Oxford, 2000)
    • tutorial from Carlo Tomasi (Duke University)
    • book "Introducing Monte Carlo Methods with R" from Robert and and Casella (2009)