This post looks at estimating a Dirichlet process (DP) mixture via a finite approximation to the stick-breaking construction of a DP. There is a direct analogy to finite mixture models. The model here is (I apologize for the notation…eventually I will figure out how to incorporate LaTeX):

  • y_i ~ N(\mu_i, \sigma_i^2)
  • (\mu_i, \sigma_i^2) ~ P
  • P ~ DP(\alpha P_0)

where \alpha is the concentration parameter of the Dirichlet process (alpha in the code below) and P_0 is the base measure. In this case, the base measure is a Normal-Inverse Gamma(m0, t0, a, b). The stick-breaking construction allows for an alternative representation of the model via

  • y_i ~ N(\mu_i, \sigma_i^2)
  • (\mu_i, \sigma_i^2) ~ \sum_{h=1}^\infty \pi_h \delta_{(\mu_h^*, \sigma_h^{2*})}
  • \pi ~ stick-breaking(\alpha)
  • (\mu_h^*, \sigma_h^{2*}) ~ Normal-Inverse Gamma(m0, t0, a, b)

One way to estimate this model is to use a finite approximation (truncating at H components) to the infinite sum in this stick-breaking construction.

The original version of this code is from here and modified by me. I’m sure any mistakes are my own.

The first part of this code just has the function to build the stick-breaking construction.

suppressMessages(library(mixtools))
## Error in library(mixtools): there is no package called 'mixtools'
suppressMessages(library(Hmisc))
suppressMessages(library(DPpackage))
## Error in library(DPpackage): there is no package called 'DPpackage'
suppressMessages(library(weights))
## Error in library(weights): there is no package called 'weights'
# Stick-breaking realizations
# Convert stick-breaking fractions into mixture weights.
#
# Args:
#   v: numeric vector of break fractions (expected to lie in [0, 1]); v[h] is
#      the share of the remaining stick given to component h.
# Returns:
#   Numeric vector of the same length as v with pi[1] = v[1] and
#   pi[h] = v[h] * prod(1 - v[1:(h - 1)]) for h > 1.
calc_pi <- function(v) {
  n <- length(v)
  # Stick length left before component h: 1 for h = 1, then the running
  # product of (1 - v). Keeping the first n entries also handles length-0 and
  # length-1 input, where a `for (i in 2:n)` loop would count backwards.
  remaining <- c(1, cumprod(1 - v))[seq_len(n)]
  as.numeric(v) * remaining
}

Next we simulate the data using the mixture-of-normals simulator from the ‘mixtools’ package.

# Dirichlet process mixture

# Data
# Simulation settings: fixed seed, sample size, and the parameters of the
# three-component normal mixture used to generate the data.
set.seed(1)
n <- 500
truth <- data.frame(
  pi    = c(0.1, 0.5, 0.4),      # mixing proportions
  mu    = c(-3, 0, 3),           # component means
  sigma = sqrt(c(0.5, 0.75, 1))  # component standard deviations
)
# True mixture density evaluated at x.
#
# Reads the global data frame `truth` (columns pi, mu, sigma) and returns
# sum_k truth$pi[k] * dnorm(x, truth$mu[k], truth$sigma[k]).
#
# Args:
#   x: numeric vector of evaluation points.
# Returns:
#   Numeric vector of density values, same length as x.
f <- function(x) {
  out <- numeric(length(x))
  # seq_along() rather than 1:length(): with zero components the loop body is
  # skipped instead of iterating over c(1, 0).
  for (k in seq_along(truth$pi)) {
    out <- out + truth$pi[k] * dnorm(x, truth$mu[k], truth$sigma[k])
  }
  out
}
# Draw n observations from the true mixture (rnormmix is from 'mixtools').
y <- rnormmix(n, lambda = truth$pi, mu = truth$mu, sigma = truth$sigma)
## Error in rnormmix(n, truth$pi, truth$mu, truth$sigma): could not find function "rnormmix"

Now we set up the prior, initial values for the MCMC, and objects to record the results.

# Prior
alpha <- 1     # DP scale (concentration) parameter

# Base-measure hyperparameters, shared by all clusters.
# Original note: theta|sigma2 ~ N(m0, sigma2/t0), sigma2 ~ IG(a, b).
# NOTE(review): the mean update in the sampler uses 1/(t0 + n_h/sigma2_h) as
# the posterior variance, i.e. it treats t0 as an absolute prior precision
# (mu_h ~ N(m0, 1/t0)) -- confirm which prior is intended.
m0 <- 0        # prior mean of mu, all clusters
t0 <- 0.001    # prior precision of mu, all clusters
b <- 0.01      # inverse-gamma hyperparameter for sigma2, all clusters
a <- b         # inverse-gamma hyperparameter for sigma2, all clusters
H <- 20        # max number of clusters in the blocked (truncated) DP


# Inits
ns <- numeric(H)                # number of subjects per cluster
# Conditional weights -- pr(c_i = h | c_i not in l < h); the last entry is 1
# so that the DP is truncated to H clusters.
v <- c(rep(1 / H, H - 1), 1)

# Starting weights come from fresh Beta(1, alpha) stick fractions; the random
# draws are made in the same order as before (rbeta first, then rnorm).
pi <- calc_pi(rbeta(H, 1, alpha))
mu <- rnorm(H)                  # cluster-specific means
sigma2 <- rep(1, H)             # cluster-specific variances

# p[i, h] = conditional prob that subject i belongs to cluster h;
# tmp2 holds the unnormalised version.
tmp2 <- matrix(0, n, H)
p <- tmp2
 

# Record
nsim <- 1000   # number of MCMC iterations

# Evaluation grid spanning the observed data, used for the density estimate.
# `length.out` is spelled out instead of relying on partial matching of
# `length`.
grid <- seq(min(y), max(y), length.out = 500)
## Error in seq(min(y), max(y), length = 500): object 'y' not found
# Storage for the MCMC output; row i of each matrix holds iteration i.
keep <- list(
  v      = matrix(0, nsim, H),
  mu     = matrix(0, nsim, H),
  sigma2 = matrix(0, nsim, H),
  # The sampler stores standard deviations with `keep$sigma[i, ] <- ...` and
  # the summaries read `keep$sigma`. Without this element that only works
  # because `$` partially matches "sigma" to "sigma2", so allocate it here.
  sigma  = matrix(0, nsim, H),
  n      = matrix(0, nsim, H),
  pi     = matrix(0, nsim, H),
  S      = matrix(0, nsim, n),
  y      = array(0, dim = c(nsim, length(grid), H))
)

The MCMC involves iterating through the following steps:

  • Sample cluster indicators S
  • Sample mixture probabilities pi and mixture parameters mu and sigma^2.

This is analogous to finite mixture models except in the updating of the mixture probabilities we need to use the stick-breaking construction.

Without further ado, we run the MCMC.

# Run MCMC
# Each sweep updates (1) the cluster labels S, (2) the stick fractions v and
# weights pi, (3) the cluster means and variances, then stores the draws.

# Standard deviations are stored in keep$sigma below; allocate it if the
# storage list lacks it so the assignment never depends on `$` partially
# matching "sigma2".
if (is.null(keep[["sigma"]])) keep$sigma <- matrix(0, nsim, H)

for (i in seq_len(nsim)) {
  # Update S, cluster indicator: p[j, h] is proportional to
  # pi[h] * N(y[j]; mu[h], sigma2[h]), normalised across clusters.
  for (h in seq_len(H)) tmp2[, h] <- pi[h] * dnorm(y, mu[h], sqrt(sigma2[h]))
  p <- tmp2 / rowSums(tmp2)
  S <- rMultinom(p, 1)

  # Size of each cluster, including zeros. tabulate() indexes the counts by
  # cluster label; table(S) drops empty clusters, so copying it into
  # ns[1:length(tS)] put counts on the wrong labels whenever a lower-numbered
  # cluster was empty.
  # NOTE(review): assumes rMultinom returns integer labels in 1..H (p has no
  # column names) -- confirm.
  ns <- tabulate(as.integer(S), nbins = H)

  # Update v and pi: v[h] ~ Beta(1 + n_h, alpha + sum_{l > h} n_l).
  # v[H] is not touched here, so it keeps its initial value.
  for (h in seq_len(H - 1)) {
    v[h] <- rbeta(1, 1 + ns[h], alpha + sum(ns[(h + 1):H]))
  }
  pi <- calc_pi(v)

  # Update mu and sigma2, one cluster at a time (mean first, then variance
  # given the new mean).
  for (h in seq_len(H)) {
    in_h <- S == h
    post_var <- 1 / (t0 + ns[h] / sigma2[h])
    post_mean <- post_var * (t0 * m0 + sum(y[in_h]) / sigma2[h])
    mu[h] <- rnorm(1, post_mean, sqrt(post_var))
    sigma2[h] <- 1 / rgamma(1, a + ns[h] / 2, b + sum((y[in_h] - mu[h])^2) / 2)
  }

  # Record parameters
  keep$v[i, ] <- v
  keep$mu[i, ] <- mu
  keep$sigma2[i, ] <- sigma2         # variances
  keep$sigma[i, ] <- sqrt(sigma2)    # standard deviations
  keep$n[i, ] <- ns                  # number per cluster
  keep$pi[i, ] <- pi
  keep$S[i, ] <- S
  # Per-cluster contribution to the mixture density on the grid
  for (h in seq_len(H)) {
    keep$y[i, , h] <- pi[h] * dnorm(grid, mu[h], sqrt(sigma2[h]))
  }
}
## Error in dnorm(y, mu[h], sqrt(sigma2[h])): object 'y' not found

Next we calculate some posterior summaries and compare to the truth. We only used point estimates here, but we could have easily provided uncertainty estimates as well.

# Summaries
iters <- 501:nsim   # iterations kept for the summaries (first 500 dropped)

# Avg number per cluster, taken from the recorded cluster sizes. The previous
# version tabulated rounded *mean labels* from keep$S, whose entries are all
# positive, so it always selected the first few clusters rather than the
# occupied ones.
nh <- colMeans(keep$n[iters, , drop = FALSE])

# Clusters holding at least one observation on average (after rounding)
occupied <- which(round(nh) > 0)

# Avg proportion per cluster
mprop <- round(nh[occupied] / n, 2)

# Cluster means and standard deviations
# NOTE(review): averaging per label assumes the labels do not switch between
# iterations -- confirm from the trace plots.
mmu <- round(colMeans(keep$mu[iters, , drop = FALSE])[occupied], 2)
msigma <- round(colMeans(keep$sigma[iters, , drop = FALSE])[occupied], 2)

est <- data.frame(pi = mprop,
                  mu = mmu,
                  sigma = msigma)
est
##   pi mu sigma
## 1  0  0     0
# Print the data-generating mixture parameters for comparison with `est`.
truth
##    pi mu     sigma
## 1 0.1 -3 0.7071068
## 2 0.5  0 0.8660254
## 3 0.4  3 1.0000000

We can also estimate the unknown density of the data via estimating the approximation on a finite grid of y values. This will be used below to compare to the truth as well as to the ‘DPdensity’ function from the DPpackage.

# Calculate density on a grid of y values
# Ytmp[i, j] is the mixture density at grid[j] for iteration i, i.e. the sum
# over clusters of keep$y[i, j, ].
Ytmp <- apply(keep$y, c(1, 2), sum)

# Point estimate of the density: average over the retained iterations only
# (`iters`), so burn-in sweeps are excluded just as in the other summaries.
yhat <- colMeans(Ytmp[iters, , drop = FALSE])

Here is the same analysis using the DPpackage.

# DPpackage
# MCMC settings for DPdensity: burn-in, saved draws, thinning, progress output
mcmc <- list(nburn = 1000, nsave = 1000, nskip = 1, ndisplay = 100)

# Prior specification passed to DPdensity (fixed precision alpha = 1)
prior1 <- list(alpha = 1, m1 = rep(0, 1), psiinv1 = diag(0.5, 1), nu1 = 4,
               tau1 = 1, tau2 = 100)

# `state` is passed to DPdensity() but was never defined in this script, so
# the call failed with "object 'state' not found". This is a new run
# (status = TRUE), so there is no previous state to continue from.
state <- NULL
fit <- DPdensity(y = y, prior = prior1, mcmc = mcmc,
                 state = state, status = TRUE)
## Error in DPdensity(y = y, prior = prior1, mcmc = mcmc, state = state, : could not find function "DPdensity"

Now we can compare our fit and the DPdensity fit to the truth (as well as the data).

# Compare point estimates
# Histogram of the data with three density curves overlaid: our truncated-DP
# estimate (yhat), the DPpackage fit, and the true mixture density f.
#plot(fit,ask=F)
# freq = FALSE (not F, which can be reassigned) puts the histogram on the
# density scale so the curves are comparable.
hist(y, breaks = 20, freq = FALSE,
     main = "Dirichlet Process Fit to Three-Component Mixture",
     ylim = c(0, .25))
## Error in hist(y, breaks = 20, freq = F, main = "Dirichlet Process Fit to Three-Component Mixture", : object 'y' not found
lines(grid, yhat, type = "l", lwd = 2, lty = 3, col = "darkgreen")
## Error in as.double(x): cannot coerce type 'closure' to vector of type 'double'
lines(fit$x1, fit$dens, type = "l", col = "darkred", lwd = 2, lty = 2)
## Error in lines(fit$x1, fit$dens, type = "l", col = "darkred", lwd = 2, : object 'fit' not found
#lines(density(y),type="l",col="blue4",lwd=2,lty=2)
curve(f, col = "black", lwd = 2, add = TRUE)
## Error in plot.xy(xy.coords(x, y), type = type, ...): plot.new has not been called yet
legend("topright", col = c("darkgreen", "darkred", "black"), lty = c(3, 2, 1),
       legend = c("Estimated", "DPpackage", "True Density"), lwd = 2, cex = .9)
## Error in strwidth(legend, units = "user", cex = cex, font = text.font): plot.new has not been called yet


blog comments powered by Disqus