## ==================================================== ##
## Functions for Multivariate Normal Mixture Estimation ##
##                                                      ##
## Author: Yong Wang (yongwang@auckland.ac.nz)          ##
##         Department of Statistics                     ##
##         University of Auckland                       ##
##         New Zealand                                  ##
## ==================================================== ##

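## Common argument names:
## 
## x    Data matrix, one observation per row (called `data` in the fitting functions)
## mix  Mixing distribution (an object of class "mvnmix")
## k    Number of components

## n     Number of observations
## m     Number of dimensions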

## CNM without hierarchy

cnmmb = function(data, mix, bwf=1, by, verbose=0, maxit=100,
                 nt=20, tol=1e-4, mahalanobis.tol=0.2,
                 plot=c("null","marginals","density","components","gradient")) {
  ## CNM-type algorithm (without hierarchy) for a homoscedastic multivariate
  ## normal mixture.  Each outer iteration
  ##   1. runs `tem` EM steps that update the proportions, the means and the
  ##      common covariance; the covariance's Cholesky factor is rescaled so
  ##      that prod(diag(factor))^(1/m) stays equal to h;
  ##   2. proposes new support points: the current means plus nt draws from a
  ##      mixture centred at the observations with weights 1/f(x_i), all moved
  ##      by 5 modal-EM steps; between consecutive midpoints of the current
  ##      means (along column `by`) the candidate with the largest gradient is
  ##      kept;
  ##   3. re-optimises the proportions over old and new points with pnnls()
  ##      and a backtracking line search, then merges close components.
  ##
  ## data             numeric matrix (other inputs go through as.matrix())
  ## mix              optional starting "mvnmix"; built by initial.mvnmix() when
  ##                  missing or NULL, otherwise sorted by column `by`
  ## bwf              the default starting covariance is cov(data) * bwf^2
  ## by               column used to order components (default: the column
  ##                  with the fewest duplicated values)
  ## verbose          verbosity passed to print.cnmmbverbose()
  ## maxit            maximum number of outer iterations
  ## nt               number of random candidate points per iteration
  ## tol              minimal log-likelihood improvement counted as progress
  ## mahalanobis.tol  merging tolerance passed to collapse.mvnmix()
  ## plot             progress plot type ("null" for none)
  ##
  ## Returns an object of class "cnmmb" with num.iterations, convergence
  ## (0 = stopped because the log-likelihood no longer improved, 1 = otherwise),
  ## max.gradient (largest gradient among the last selected points), ll, mix,
  ## h and bwf.
  if(! is.numeric(data)) data = as.matrix(data)
  plot = match.arg(plot)
  n = nrow(data)
  m = ncol(data)
  Sigma = cov(data) * bwf^2
  if(missing(by)) by = which.min(colSums(apply(data, 2, duplicated)))
  if(missing(mix) || is.null(mix))
    mix = initial.mvnmix(data, Sigma=chol(Sigma), chol.Sigma=TRUE, by=by)
  else mix = sort(mix, by=by)
  h = prod(diag(mix$Sigma))^(1/m)     # scale held fixed by the EM step
  convergence = 1
  ll = logLik(mix, data)
  if(maxit < 1)
    return( structure(list(num.iterations=0, convergence=1, max.gradient=NULL,
                           ll=ll[1], mix=mix, h=h, bwf=bwf),
                      class="cnmmb") )
  tems = 5                  # EM steps per iteration in normal operation
  tem = tems                # raised to 100 after an iteration without progress
  for(i in 1:maxit) {
    ll1 = ll
    if(verbose > 0) print.cnmmbverbose(verbose, i-1, ll, mix)
    if(plot != "null") plot(mix, data, type=plot)

    ## ---- EM: proportions, means and common covariance ----
    xk = array(data, dim=c(n,m,length(mix$pr)))
    for(iem in 1:tem) {
      logd = outer.dmvn(data, mix$mu, mix$Sigma, mix$chol.Sigma, log=TRUE) +
        rep(log(mix$pr), each=n)
      ma = matMaxs(logd); pi.d = exp(logd - ma); pi.s = rowSums(pi.d)
      p = pi.d / pi.s                   # posterior component probabilities
      mix$pr = colMeans(p)
      if(any(j <- mix$pr < 1e-10)) {    # drop (numerically) empty components
        mix$pr = mix$pr[!j]
        mix$pr = mix$pr / sum(mix$pr)
        p = p[,!j,drop=FALSE]
        xk = array(data, dim=c(n,m,length(mix$pr)))
      }
      nc = length(mix$pr)
      xkp = sweep(xk, c(1,3), p, "*")
      dim(xkp) = c(n, m*nc)
      mu = colSums(xkp)
      dim(mu) = c(m, nc)
      mix$mu = t(mu) / mix$pr / n
      xkp = sweep(sweep(xk, 2:3, t(mix$mu)), c(1,3), sqrt(p), "*")
      Sigma = 0
      for(j in seq_len(nc)) Sigma = Sigma + crossprod(xkp[,,j])
      R = chol(Sigma) / sqrt(n)
      mix$Sigma = h * R / prod(diag(R))^(1/m)
      mix$chol.Sigma = TRUE
      if(tem > tems) {                  # long EM run: stop once it stalls
        ll.old = ll
        ll = logLik(mix, data)
        if(ll <= ll.old + tol) break
      }
    }
    if(tem <= tems) ll = logLik(mix, data)

    ## ---- Candidate support points ----
    ## w (the floored mixture density at the data) is reused for the
    ## gradient and for the CNM step below; mix does not change in between.
    w = pmax(dmvnmix(data, mix), 1e-100)
    gmix = mvnmix(data, 1/w, mix$Sigma, mix$chol.Sigma, sort=FALSE)
    cand = rbind(mix$mu, rmvnmix(nt, gmix))
    k = nrow(cand)
    xk = array(data, dim=c(n, m, k))
    for(imem in 1:5) {                # modal EM on gmix
      logd = outer.dmvn(cand, gmix$mu, gmix$Sigma, gmix$chol.Sigma, log=TRUE) +
        rep(log(gmix$pr), each=k)
      ma = matMaxs(logd); pi.d = exp(logd - ma); pi.s = rowSums(pi.d)
      p = pi.d / pi.s
      xkp = sweep(xk, c(3,1), p, "*")
      dim(xkp) = c(n, m*k)
      cand = colSums(xkp)
      dim(cand) = c(m, k)
      cand = t(cand)
    }

    ## Between consecutive midpoints of the current means along column `by`,
    ## keep the candidate with the largest gradient.
    mids = (mix$mu[-1,by] + mix$mu[-length(mix$pr),by]) * 0.5
    cand = cand[order(cand[,by]),,drop=FALSE]
    d = outer.dmvn(data, cand, mix$Sigma, mix$chol.Sigma)
    g = colSums(d / w) - n
    index = indx(cand[,by], mids) + 1
    jj = aggregate(g, by=list(group=index), which.max)
    j = match(jj$group, index) + jj$x - 1
    pt = cand[j,,drop=FALSE]
    gpt = g[j]

    ## ---- CNM step on the proportions ----
    ## New points enter with zero weight (one zero per row of pt).
    mix2 = mvnmix(rbind(mix$mu, pt), c(mix$pr, double(nrow(pt))), mix$Sigma,
                  mix$chol.Sigma, by=by)
    S = outer.dmvn(data, mix2$mu, mix2$Sigma, mix2$chol.Sigma) / w
    grad = colSums(S)
    r = pnnls(S, 2, sum=1)
    mix = mix2
    mix$pr = r$x / sum(r$x)
    dea = sum(grad * (mix$pr - mix2$pr)) * .333
    ## Backtracking line search between mix2$pr (alpha = 0) and mix$pr
    ## (alpha = 1); llt is always the log-likelihood of the current mixt.
    alpha = 1
    mixt = mix
    llt = logLik(mixt, data)
    for(ils in 1:50) {
      if( llt[1] >= ll[1] + dea * alpha ) break
      alpha = alpha * .5
      mixt$pr = (1 - alpha) * mix2$pr + alpha * mix$pr
      llt = logLik(mixt, data)
    }
    ll = llt
    mix = collapse.mvnmix(mixt, mahalanobis.tol, by=by)

    ## ---- Convergence ----
    ## After the first iteration without progress, switch to long EM runs
    ## (tem = 100); stop only if that still brings no improvement.
    if( ll[1] <= ll1[1] + tol ) {
      if(tem == tems) tem = 100
      else {convergence = 0; break}
    }
    else tem = tems
  }
  if(verbose > 0) print.cnmmbverbose(verbose, i, ll, mix)
  structure(list(num.iterations=i, convergence=convergence,
                 max.gradient=max(gpt), ll=ll[1], mix=mix, h=h, bwf=bwf), 
            class="cnmmb")
}


## hierarchical CNM

hcnmmb = function(data, mix, bwf=1, by, verbose=0, maxit=100,
                 nt=20, tol=1e-4, mahalanobis.tol=0.2,
                 plot=c("null","marginals","density","components","gradient")) {
  ## Hierarchical variant of cnmmb(): identical EM and candidate-point steps,
  ## but the proportions over old and new support points are updated by
  ## hcnm() (defined elsewhere; called with maxit=3) instead of
  ## pnnls() + line search.
  ## NOTE(review): assumes hcnm() returns the new proportions in $pf and the
  ## log-likelihood in $ll -- confirm against its definition.
  ##
  ## Arguments and the returned "cnmmb" object are as for cnmmb():
  ## data (numeric matrix), mix (optional starting "mvnmix"), bwf (starting
  ## covariance cov(data) * bwf^2), by (ordering column), verbose, maxit,
  ## nt (random candidates per iteration), tol (minimal log-likelihood
  ## improvement), mahalanobis.tol (merging tolerance) and plot.
  if(! is.numeric(data)) data = as.matrix(data)
  plot = match.arg(plot)
  n = nrow(data)
  m = ncol(data)
  Sigma = cov(data) * bwf^2
  if(missing(by)) by = which.min(colSums(apply(data, 2, duplicated)))
  if(missing(mix) || is.null(mix))
    mix = initial.mvnmix(data, Sigma=chol(Sigma), chol.Sigma=TRUE, by=by)
  else mix = sort(mix, by=by)
  h = prod(diag(mix$Sigma))^(1/m)     # scale held fixed by the EM step
  convergence = 1
  ll = logLik(mix, data)
  if(maxit < 1)
    return( structure(list(num.iterations=0, convergence=1, max.gradient=NULL,
                           ll=ll[1], mix=mix, h=h, bwf=bwf),
                      class="cnmmb") )
  tems = 5                  # EM steps per iteration in normal operation
  tem = tems                # raised to 100 after an iteration without progress
  for(i in 1:maxit) {
    ll1 = ll
    if(verbose > 0) print.cnmmbverbose(verbose, i-1, ll, mix)
    if(plot != "null") plot(mix, data, type=plot)

    ## ---- EM: proportions, means and common covariance ----
    xk = array(data, dim=c(n,m,length(mix$pr)))
    for(iem in 1:tem) {
      logd = outer.dmvn(data, mix$mu, mix$Sigma, mix$chol.Sigma, log=TRUE) +
        rep(log(mix$pr), each=n)
      ma = matMaxs(logd); pi.d = exp(logd - ma); pi.s = rowSums(pi.d)
      p = pi.d / pi.s                   # posterior component probabilities
      mix$pr = colMeans(p)
      if(any(j <- mix$pr < 1e-10)) {    # drop (numerically) empty components
        mix$pr = mix$pr[!j]
        mix$pr = mix$pr / sum(mix$pr)
        p = p[,!j,drop=FALSE]
        xk = array(data, dim=c(n,m,length(mix$pr)))
      }
      nc = length(mix$pr)
      xkp = sweep(xk, c(1,3), p, "*")
      dim(xkp) = c(n, m*nc)
      mu = colSums(xkp)
      dim(mu) = c(m, nc)
      mix$mu = t(mu) / mix$pr / n
      xkp = sweep(sweep(xk, 2:3, t(mix$mu)), c(1,3), sqrt(p), "*")
      Sigma = 0
      for(j in seq_len(nc)) Sigma = Sigma + crossprod(xkp[,,j])
      R = chol(Sigma) / sqrt(n)
      mix$Sigma = h * R / prod(diag(R))^(1/m)
      mix$chol.Sigma = TRUE
      if(tem > tems) {                  # long EM run: stop once it stalls
        ll.old = ll
        ll = logLik(mix, data)
        if(ll <= ll.old + tol) break
      }
    }
    if(tem <= tems) ll = logLik(mix, data)

    ## ---- Candidate support points ----
    ## w (the floored mixture density at the data) is reused for the gradient.
    w = pmax(dmvnmix(data, mix), 1e-100)
    gmix = mvnmix(data, 1/w, mix$Sigma, mix$chol.Sigma, sort=FALSE)
    cand = rbind(mix$mu, rmvnmix(nt, gmix))
    k = nrow(cand)
    xk = array(data, dim=c(n, m, k))
    for(imem in 1:5) {                # modal EM on gmix
      logd = outer.dmvn(cand, gmix$mu, gmix$Sigma, gmix$chol.Sigma, log=TRUE) +
        rep(log(gmix$pr), each=k)
      ma = matMaxs(logd); pi.d = exp(logd - ma); pi.s = rowSums(pi.d)
      p = pi.d / pi.s
      xkp = sweep(xk, c(3,1), p, "*")
      dim(xkp) = c(n, m*k)
      cand = colSums(xkp)
      dim(cand) = c(m, k)
      cand = t(cand)
    }

    ## Between consecutive midpoints of the current means along column `by`,
    ## keep the candidate with the largest gradient.
    mids = (mix$mu[-1,by] + mix$mu[-length(mix$pr),by]) * 0.5
    cand = cand[order(cand[,by]),,drop=FALSE]
    d = outer.dmvn(data, cand, mix$Sigma, mix$chol.Sigma)
    g = colSums(d / w) - n
    index = indx(cand[,by], mids) + 1
    jj = aggregate(g, by=list(group=index), which.max)
    j = match(jj$group, index) + jj$x - 1
    pt = cand[j,,drop=FALSE]
    gpt = g[j]

    ## ---- HCNM step on the proportions ----
    ## New points enter with zero weight (one zero per row of pt).
    mix2 = mvnmix(rbind(mix$mu, pt), c(mix$pr, double(nrow(pt))), mix$Sigma,
                  mix$chol.Sigma, by=by)
    D = outer.dmvn(data, mix2$mu, mix2$Sigma, mix2$chol.Sigma)
    r = hcnm(D, p0=mix2$pr, w=1, maxit=3)
    mix2$pr = r$pf
    ll = r$ll
    mix = collapse.mvnmix(mix2, mahalanobis.tol, by=by)

    ## ---- Convergence ----
    ## After the first iteration without progress, switch to long EM runs
    ## (tem = 100); stop only if that still brings no improvement.
    if( ll[1] <= ll1[1] + tol ) {
      if(tem == tems) tem = 100
      else {convergence = 0; break}
    }
    else tem = tems
  }
  if(verbose > 0) print.cnmmbverbose(verbose, i, ll, mix)
  structure(list(num.iterations=i, convergence=convergence,
                 max.gradient=max(gpt), ll=ll[1], mix=mix, h=h, bwf=bwf), 
            class="cnmmb")
}

# r = cnmmb(iris[,1:4], bwf=.6, plot="n")


## Modal EM

mem = function(t, mix, maxit=5) {
  ## Modal EM: move each row of `t` (a k x m matrix of points) towards a mode
  ## of the mixture `mix` by repeating, `maxit` times,
  ##   t_k <- sum_i p_ki * mu_i,  p_ki proportional to pr_i * phi(t_k; mu_i, Sigma).
  ## Returns the updated k x m matrix; with maxit = 0 the input is returned
  ## unchanged (1:maxit would have run two iterations).
  k = nrow(t)
  m = ncol(t)
  n = nrow(mix$mu)
  muk = array(mix$mu, dim=c(n, m, k))   # component means, replicated per point
  pts = t
  for(iter in seq_len(maxit)) {
    logd = outer.dmvn(pts, mix$mu, mix$Sigma, mix$chol.Sigma, log=TRUE) +
      rep(log(mix$pr), each=k)
    ma = matMaxs(logd); pi.d = exp(logd - ma); pi.s = rowSums(pi.d)
    p = pi.d / pi.s                       # k x n responsibilities
    wmu = sweep(muk, c(3,1), p, "*")
    dim(wmu) = c(n, m*k)
    pts = t(matrix(colSums(wmu), nrow=m, ncol=k))
  }
  pts
}

plot.cnmmb = function(x, data,
     type=c("marginals","density","components","gradient","pairs"),
     ...) {
  ## S3 plot method for "cnmmb" fits: draws the fitted mixture x$mix
  ## (optionally against `data`) via plot.mvnmix(); returns the fit invisibly.
  chosen = match.arg(type)
  fitted.mix = x$mix
  plot.mvnmix(fitted.mix, data, type=chosen, ...)
  invisible(x)
}

print.cnmmbverbose = function(verbose, i, ll, mix) {
  ## Progress report used by cnmmb()/hcnmmb().  Each level adds to the
  ## previous one:
  ##   >= 1  iteration number and log-likelihood
  ##   >= 2  component proportions
  ##   >= 3  component means
  ##   >= 4  common covariance matrix (rebuilt from its Cholesky factor
  ##         when mix$chol.Sigma is TRUE)
  if(verbose < 1) return(invisible(NULL))
  cat("## Iteration ", i, ":\nLog-likelihood = ",
      as.character(signif(ll[1], 16)), "\n", sep="")
  if(verbose < 2) return(invisible(NULL))
  cat("Component proportions:\n")
  print(mix$pr)
  if(verbose < 3) return(invisible(NULL))
  cat("Component means:\n")
  print(mix$mu)
  if(verbose < 4) return(invisible(NULL))
  cat("Component covariance matrix:\n");
  print(if(mix$chol.Sigma) crossprod(mix$Sigma) else mix$Sigma)
}

grad.mvnmix = function(pt, mix, data) {
  ## Gradient function of the mixture log-likelihood at candidate support
  ## points (rows of pt):  g(t) = sum_i phi(x_i; t, Sigma) / f(x_i) - n.
  ## The mixture density f is floored at 1e-100, as in cnmmb()/hcnmmb(),
  ## so that observations with an underflowed density do not yield Inf/NaN.
  dmix = pmax(dmvnmix(data, mix), 1e-100)
  d = outer.dmvn(data, pt, mix$Sigma, mix$chol.Sigma)
  colSums(d / dmix) - nrow(data)
}

##

initial.mvnmix = function(data, Sigma, chol.Sigma=FALSE, by=1) {
  ## Initial homoscedastic mixture for cnmmb()/hcnmmb().
  ##
  ## The data are standardised as z = R^{-T} (x - xbar), where R is the upper
  ## Cholesky factor of Sigma (t(R) %*% R = Sigma).  Centres are then placed
  ## greedily in z-space: the origin first; then, repeatedly, at the
  ## not-yet-covered point closest to the origin, pushed outwards so that its
  ## squared norm grows by 4.  A point counts as covered once it lies within
  ## distance 3 (squared distance <= 9) of a centre.  Every observation is
  ## assigned to its nearest centre; the assignment frequencies become the
  ## proportions, and centres with no observations are dropped.
  ##
  ## data        n x m numeric matrix
  ## Sigma       covariance matrix, or its upper Cholesky factor when chol.Sigma=TRUE
  ## by          column used to sort the resulting components
  ## Returns an "mvnmix" whose Sigma is stored as the Cholesky factor R.
  n = nrow(data)
  m = ncol(data)
  R = if(chol.Sigma) Sigma else chol(Sigma)
  xmu = colMeans(data)
  zt = backsolve(R, t(data) - xmu, transpose=TRUE)   # m x n standardised data
  pts = data                  # storage for up to n centres, filled row by row
  index = 1:n
  d2 = d2c = colSums(zt^2)    # squared distances to the origin
  ## Each centre covers at least the point it is built from, so at most n
  ## centres are needed and n + 1 passes always reach the `break` (with 1:n
  ## the n-th centre could be silently dropped).
  for(k in 1:(n+1)) {
    index = index[d2 > 9]     # points not yet covered
    if(length(index) == 0) break
    imin = index[which.min(d2c[index])]
    pts[k,] = zt[,imin] * sqrt(1 + 4 / d2c[imin])
    d2 = colSums((zt[,index,drop=FALSE]-pts[k,])^2)
  }
  muz = if(k == 1) t(rep(0, m)) else rbind(0, pts[1:(k-1),,drop=FALSE]) # nrow=k
  ## squared distances from every observation to every centre (n x k)
  dsq = (zt[,rep(1:n,k)] - t(muz)[,rep(1:k,each=n)])^2
  dim(dsq) = c(m,n*k)
  d2muz = colSums(dsq)
  dim(d2muz) = c(n, k)
  imuz = apply(d2muz, 1, which.min)
  nj = tabulate(imuz, nbins=k)
  pr = nj / sum(nj)
  ## Back to the data scale: x = t(R) z + xbar, the inverse of the backsolve
  ## above (R %*% z would be wrong whenever R is not diagonal).
  mu = t(crossprod(R, t(muz)) + xmu)
  rownames(mu) = NULL
  j = pr != 0
  mvnmix(mu[j,,drop=FALSE], pr[j], R, TRUE, by=by)
}

## Constructor for a homoscedastic multivariate normal mixture ("mvnmix").
##
## mu          k x m matrix of component means (a vector is taken as one mean)
## pr          mixing proportions, recycled to length k and normalised
## Sigma       common covariance matrix, or its Cholesky factor when
##             chol.Sigma=TRUE -- any decomposition must be done by the caller
## chol.Sigma  records which of the two forms Sigma is stored in
## sort, by    when sort=TRUE, components are ordered by column(s) `by` of mu

mvnmix = function(mu=c(0,0), pr=1, Sigma=diag(nrow=ncol(mu)), chol.Sigma=FALSE,
                  sort=TRUE, by=1) {
  if(is.vector(mu)) mu = t(mu)
  pr = rep(pr, length.out=nrow(mu))
  mix = structure(list(mu=mu, pr=pr / sum(pr), Sigma=Sigma,
                       chol.Sigma=chol.Sigma),
                  class="mvnmix")
  if(!sort) return(mix)
  sort(mix, by=by)
}

sort.mvnmix = function(mix, by=1, decreasing=FALSE) {
  ## Reorder the components of a "mvnmix" by the mean coordinate(s) in
  ## column(s) `by` of mix$mu (later entries of `by` break ties).
  ## The sort keys are handed to order() unnamed: column names of mu such as
  ## "method" or "decreasing" would otherwise be matched to order()'s own
  ## arguments instead of being used as keys.
  keys = lapply(unname(by), function(b) mix$mu[, b])
  o = do.call(order, c(keys, list(decreasing=decreasing)))
  mix$mu = mix$mu[o,,drop=FALSE]
  mix$pr = mix$pr[o]
  mix
}

# is.unsorted.mvnmix = function(mix, by=1) is.unsorted(mix$mu[,by])

print.mvnmix = function(mix, ...) {
  ## Print a "mvnmix" object as the plain list (mu, pr, Sigma, chol.Sigma).
  ## Extra arguments (e.g. digits=) are forwarded to print(), as the print()
  ## generic expects; the object is returned invisibly.
  print(unclass(mix), ...)
  invisible(mix)
}

## Random sample of size n from a homoscedastic multivariate normal mixture:
## each row is mu[component, ] + z %*% R, with the component drawn with
## probabilities mix$pr, z standard normal, and R the (possibly stored)
## Cholesky factor of the common covariance.  n = 0 gives a 0 x m matrix.

rmvnmix = function(n, mix=mvnmix(chol.Sigma=TRUE)) {
  dims = ncol(mix$mu)
  if(n == 0) return( matrix(ncol=dims, nrow=0) )
  comp = suppressWarnings(sample.int(nrow(mix$mu), n, prob=mix$pr, replace=TRUE))
  noise = matrix(rnorm(n*dims), nrow=n, ncol=dims)
  R = if(mix$chol.Sigma) mix$Sigma else chol(mix$Sigma)
  noise %*% R + mix$mu[comp,]
}

# Weighted mean of an array over the dimension(s) d.
#
# x       Array (or matrix)
# w       Weights, laid out to match dim(x)[d] (applied with sweep())
# d       Dimensions the weights apply to; the result keeps the others

weighted.mean.array = function(x, w, d) {
  kept = setdiff(seq_along(dim(x)), d)
  weighted.sum = apply(sweep(x, d, w, "*"), kept, sum)
  weighted.sum / sum(w)
}

## Merge components whose means are close in Mahalanobis distance (w.r.t. the
## common covariance).  The proportions are not used in the distance itself
## (the original note here asked "with weights?").

collapse.mvnmix = function(mix, tol=0.1, by=1) {
  ## mix  "mvnmix" object
  ## tol  merging threshold on the (unsquared) Mahalanobis distance between
  ##      two means; tol <= 0 disables merging
  ## by   column(s) used to sort the components first
  ## Returns the sorted mixture with zero-weight components removed and close
  ## pairs merged.  Pairs are visited in increasing distance; a pair is merged
  ## only if neither member has already been merged away.  Distances are not
  ## recomputed after a merge.
  if( length(mix$pr) == 1 ) return(mix)
  mix = sort.mvnmix(mix, by)
  j = mix$pr != 0                 # drop zero-weight components
  mix$mu = mix$mu[j,,drop=FALSE]
  mix$pr = mix$pr[j]
  if(length(mix$pr) == 1 || tol <= 0) return(mix)
  
  tol2 = tol * tol                # compared with squared distances
  m = nrow(mix$mu)
  ## squared Mahalanobis distances of all m*m ordered pairs of means:
  ## backsolve(R, ., transpose=TRUE) computes R^{-T} (mu_a - mu_b)
  d2 = colSums(backsolve(if(mix$chol.Sigma) mix$Sigma else chol(mix$Sigma),
                         t(mix$mu[rep(1:m, m),] - mix$mu[rep(1:m, each=m),]),
                         transpose=T)^2)
  dim(d2) = c(m, m)               # outer of squared Mahalanobis distance
  ## lower.tri() (column-major) enumerates the pairs in the same order as
  ## combn(m, 2), so each distance lines up with its pair of indices
  d2l = d2[lower.tri(d2)]
  d2i = cbind(d2l, t(combn(m,2)))
  d = d2i[order(d2i[,1]),,drop=FALSE]   # pairs sorted by distance
  
  mut = mix$mu
  pit = mix$pr
  index = rep(TRUE, m)            # FALSE once a component is merged away
  for(i in 1:nrow(d)) {
    if(d[i,1] > tol2) break       # all remaining pairs are farther apart
    ip = d[i,2:3]                 # indices of the pair
    if(all(index[ip])) {
      ## first member becomes the proportion-weighted mean of the pair
      mut[ip[1],] = weighted.mean.array(mut[ip,,drop=FALSE], pit[ip], 1)
      pit[ip[1]] = sum(pit[ip])
      index[ip[2]] = FALSE
    }
  }
  mix$mu = mut[index,,drop=FALSE]
  mix$pr = pit[index]
  mix
}

dmvnmix = function(x, mix, log=FALSE) {
  ## Density (log-density when log=TRUE) of the mixture `mix` at each row of
  ## x; a vector x is treated as a single point.  Zero-weight components are
  ## ignored.  The sum over components is done as a log-sum-exp, shifted by
  ## matMaxs() (defined elsewhere; used here as the row-wise maximum).
  if(is.vector(x)) x = t(x)
  if(any(zero <- mix$pr == 0)) {
    mix$pr = mix$pr[!zero]
    mix$mu = mix$mu[!zero,,drop=FALSE]
  }
  logjoint = outer.dmvn(x, mix$mu, mix$Sigma, mix$chol.Sigma, log = TRUE) +
    rep(log(mix$pr), each = nrow(x))
  rowmax = matMaxs(logjoint)
  out = rowmax + log(rowSums(exp(logjoint - rowmax)))
  if(log) out else exp(out)
}

pmvnmix = function(x, mix, lower.tail=TRUE, log=FALSE) {
  ## Mixture "distribution function" at each row of x (a vector x is one
  ## point).  Component j contributes pr_j * prod_l Phi(z_l) with
  ## z = R^{-T} (x - mu_j), R being the upper Cholesky factor of the common
  ## covariance; upper tails are used when lower.tail=FALSE.
  ## NOTE(review): prod_l Phi(z_l) equals the joint CDF P(X <= x) only when
  ## Sigma is diagonal; for correlated components it is the CDF of the
  ## whitened coordinates, not of X itself -- confirm this is intended.
  if(is.vector(x)) x = t(x)
  if(any(j <- mix$pr == 0)) {
    mix$mu = mix$mu[!j,,drop=FALSE]
    mix$pr = mix$pr[!j]
  }
  n = nrow(x)                 # number of evaluation points
  k = length(mix$pr)          # number of components
  ## drop=FALSE keeps matrices even when n * k == 1, so t() below always
  ## gives an (m x n*k) matrix for backsolve()
  x1 = x[rep(1:n, k),,drop=FALSE]
  mu1 = mix$mu[rep(1:k, each=n),,drop=FALSE]
  R = if(mix$chol.Sigma) mix$Sigma else chol(mix$Sigma)
  z = backsolve(R, t(x1 - mu1), transpose=TRUE)
  logpz = colSums(pnorm(z, lower.tail=lower.tail, log.p=TRUE)) +
    rep(log(mix$pr), each=n)
  dim(logpz) = c(n, k)
  ma = matMaxs(logpz)
  logp = log(rowSums(exp(logpz - ma))) + ma
  if (log) logp
  else exp(logp)
}

marginal.mvnmix = function(mix, margins=1) {
  ## Marginal mixture of the coordinates in `margins`: the corresponding
  ## columns of the means and the sub-block of the covariance.  When Sigma is
  ## stored as a Cholesky factor, the full covariance is rebuilt, subset and
  ## re-factorised.  The result is built with mvnmix() and so sorted by its
  ## first column.
  full = if(mix$chol.Sigma) crossprod(mix$Sigma) else mix$Sigma
  block = full[margins, margins, drop=FALSE]
  if(mix$chol.Sigma) block = chol(block)
  mvnmix(mix$mu[, margins, drop=FALSE], mix$pr, block, mix$chol.Sigma)
}

logLik.mvnmix = function(mix, x) {
  ## Log-likelihood of the mixture `mix` for the observations in the rows of x
  ## (a plain number, not a "logLik" object).
  logdens = dmvnmix(x, mix, log=TRUE)
  sum(logdens)
}

## Plot a "mvnmix" object, optionally together with data (drawn in grey).
##
## type  "marginals"  : one panel per coordinate with the marginal mixture
##                      density, drawn by plot() on npnorm()/disc() objects
##                      (defined elsewhere -- presumably package nspmix)
##       "density"    : contours of the joint density (m == 2; m == 1 falls
##                      back to "marginals")
##       "components" : one ellipse per component, radius 3 * sqrt(proportion)
##                      in Mahalanobis units, plus the component means
##       "gradient"   : contours of grad.mvnmix() (requires data)
##       "pairs"      : m x m panel matrix: names on the diagonal, scatter plots
##                      above it, bivariate density contours below (requires data)
## len     grid size / points per curve (default 100; 30 for "gradient")
## levels  contour levels; for "density" a single number is the number of
##         levels to choose automatically (default 5)
## add     add to the current plot instead of starting a new one
## Remaining arguments are the usual graphical parameters.  Returns mix invisibly.
plot.mvnmix = function(mix, data=NULL,
     type=c("marginals","density","components","gradient","pairs"),
     len, mfrow, levels, pch=1, col="blue", lty=1, lwd=1, xlim, ylim,
     xlab, ylab, main, cex=1, add=FALSE, ...) {
  type = match.arg(type)
  m = ncol(mix$mu)
  if(missing(main))
    main = switch(type,
                  marginals ="Marginal Densities",
                  density = "Joint Density",
                  components ="Mixture Components",
                  gradient = "Gradient Function", 
                  pairs = "Pairwise Marginal Densities") 
  if(type == "marginals") {
    if(missing(len)) len = 100
    if(missing(mfrow)) {
      ncol = ceiling(sqrt(ncol(mix$mu)))
      mfrow = c(ceiling(m/ncol), ncol)
    }
    ## leave room for an outer title unless main is empty
    if(main == "") par(oma=c(0,0,0,0), mfrow=mfrow)
    else par(oma=c(0,0,4,0), mfrow=mfrow)
    for(j in 1:m) {
      mvj = marginal.mvnmix(mix, j)
      mixj = disc(drop(mvj$mu), mvj$pr)
      xj = npnorm(numeric(0)); namej = NULL 
      if(! is.null(data) && nrow(data) != 0) 
        { xj = npnorm(data[,j]); namej = colnames(data)[j] }
      ## NOTE(review): this overwrites any user-supplied xlab in every panel
      xlab = if(!is.null(namej)) namej else substitute(x[jj], list(jj = j))
      if(missing(ylab)) ylab = "Density"
      ## standard deviation of coordinate j (the 1x1 Cholesky factor is the sd)
      sd = if(mvj$chol.Sigma) drop(mvj$Sigma) else sqrt(drop(mvj$Sigma))
      plot(xj, mixj, sd, len=len, components="null", 
           xlab=xlab, col=col, lty=lty, lwd=lwd, ylab=ylab, 
           cex=cex, cex.axis=cex, cex.lab=cex, cex.main=cex, ...)
    }
    title(main, outer=TRUE)
    ## resets (rather than restores) the layout parameters
    par(oma=rep(0,4), mfrow=c(1,1))
    return(invisible(mix))
  }
    
  if(type != "pairs" && m > 2)
    stop(paste0("Can not produce a '", type, "' plot: Data dimension > 2"))
  if(type == "density" && m == 1) {
    plot(mix, data, type="marginals", main=main,
         len=len, pch=pch, col=col, lty=lty, lwd=lwd, xlim=xlim, ylim=ylim,
         xlab=xlab, ylab=ylab, cex=cex, ...)
    return(invisible(mix))
  }
  ## full covariance matrix, whichever way it is stored
  if(mix$chol.Sigma) Sigma = crossprod(mix$Sigma)
  else Sigma = mix$Sigma
  ## ---- set up the plotting region ----
  if(is.null(data)) {
    if(type == "gradient")
      stop("Can not plot the gradient function when 'data' is not provided.\n")
    ## default limits: range of the means +/- 3 standard deviations
    if(missing(xlim))
      xlim = range(mix$mu[,1]) + c(-3,3) * sqrt(Sigma[1,1])
    if(missing(ylim))
      ylim = range(mix$mu[,2]) + c(-3,3) * sqrt(Sigma[2,2])
    if(missing(xlab)) xlab=expression(x[1])
    if(missing(ylab)) ylab=expression(x[2])
    if(!add)
      plot(0, 0, type="n", xlab=xlab, ylab=ylab, xlim=xlim, ylim=ylim, main=main,
           cex=cex, cex.axis=cex, cex.lab=cex, cex.main=cex,
           ...)
  }
  else {
    if(type != "pairs") {
      if(missing(xlim)) xlim = range(data[,1])
      if(missing(ylim)) ylim = range(data[,2])
      names = colnames(data)
      if(missing(xlab)) 
      { xlab = if(is.null(names[1])) expression(x[1]) else names[1] }
      if(missing(ylab)) 
      { ylab = if(is.null(names[2])) expression(x[2]) else names[2] }
      if(!add)
        plot(data[,1], data[,2], col="grey", pch=pch, xlim=xlim, ylim=ylim,
             xlab=xlab, ylab=ylab, main=main,
             cex=cex, cex.axis=cex, cex.lab=cex, cex.main=cex,
             ...)
      else points(data[,1], data[,2], col="grey", pch=pch)
    }
  }
  ## ---- draw the requested layer ----
  switch(type,
         components = {
           if(missing(len)) len = 100
           k = length(mix$pr)
           S = solve(Sigma)
           d = 3
           for(j in 1:k) {
             dp = d * sqrt(mix$pr[j])    ## area-proportional
             plot.ellipse(S, mu=mix$mu[j,], d=dp, add=TRUE, col=col,
                          lty=lty, lwd=lwd, len=len)
           }
           points(mix$mu[,1], mix$mu[,2], col=col, pch=8, cex=1.3)
         },
         density = {
           if(missing(len)) len = 100
           x1 =  seq(xlim[1], xlim[2], len=len)
           x2 = seq(ylim[1], ylim[2], len=len)
           z = matrix(dmvnmix(expand.grid(x1, x2), mix=mix), nrow=len)
           r = range(z); dr = diff(r)
           if(missing(levels) || length(levels) == 0) levels = 5
           ## levels = signif(seq(r[2]*.95, r[1]+0.03*dr, len=5), 2)
           ## a single number = how many levels, geometric in the peak height
           if(length(levels) == 1)
             levels = signif(r[2] * .8 * .3^(0:(levels-1)), 2)
##             levels = signif(quantile(r, (levels:1-.8)/(levels-.6), len=levels), 2)
           ## levels = signif(seq(r[2]*.95, r[1]+0.001*dr, len=levels), 2)
           contour(x1, x2, z, levels=levels, col=col, lty=lty, lwd=lwd,
                   labcex=0.6*cex, add=TRUE)
         },
         gradient = {
           if(missing(len)) len = 30
           x1 =  seq(xlim[1], xlim[2], len=len)
           x2 = seq(ylim[1], ylim[2], len=len)
           z = matrix(grad.mvnmix(expand.grid(x1, x2), mix, data), nrow=len)
           if(missing(levels)) {
             ## five levels in the upper part of the gradient's range
             r = range(z); dr = diff(r)
             levels = signif(seq(r[2]-.02*dr, r[2]-.8*dr, len=5), 2)
             ## levels = signif(r[2] * .8 * .3^(0:4), 2)
           }
           contour(x1, x2, z, levels=levels, col=col, lty=lty, lwd=lwd, add=TRUE)
           points(mix$mu[,1], mix$mu[,2], col=col, pch=8, cex=1.3)
         },
         pairs = {
           ## NOTE(review): levels and len are set here but not passed to the
           ## per-panel density plots below
           if(missing(levels) || length(levels) == 0) levels = 5
           if(missing(len)) len = 100
           m = ncol(data)
           names = colnames(data)
           if (is.null(names)) names = paste0("X",1:m)
           if(main == "")
             opar = par(mfrow=c(m, m), mar=rep(0, 4), oma=c(4.1, 4.1, 4.1, 4.1))
           else
             opar = par(mfrow=c(m, m), mar=rep(0, 4), oma=c(4.1, 4.1, 6.1, 4.1))
           for(i in 1:m)
             for(j in 1:m) {
               xlim = range(data[,j], na.rm=TRUE)
               if (i == j) {
                 ## diagonal: variable name
                 plot.new()
                 plot.window(xlim = xlim, ylim = xlim)
                 text(mean(xlim), mean(xlim), names[j], cex=cex*2, font=1)
                 box()
               }
               else {
                 if(i > j) {
                   ## below the diagonal: bivariate marginal density contours
                   mix2 = marginal.mvnmix(mix, c(j,i))
                   plot.mvnmix(mix2, data[,c(j,i)], type="d", main="", cex=cex,
                               col=col, lty=lty, lwd=lwd, xaxt="n", yaxt="n")
                 }
                 else {
                   ## above the diagonal: scatter plot of the data
                   plot(data[,j], data[,i], main="", cex=cex, col="grey",
                        xaxt="n", yaxt="n")
                 }
               }
               ## alternate axes around the outer border of the panel matrix
               if (i %% m == 0 && j %% 2 == 1) axis(1, cex.axis=cex)
               if (j %% m == 1 && i %% 2 == 0) axis(2, cex.axis=cex)
               if (i %% m == 1 && j %% 2 == 0) axis(3, cex.axis=cex)
               if (j %% m == 0 && i %% 2 == 1) axis(4, cex.axis=cex)
             }
           if(main != "") title(main, outer=TRUE, cex.main=cex*2)
           par(opar) } )
  invisible(mix)
}

# Draw the ellipse(s) {x : (x - mu)' A (x - mu) = d^2}, one per entry of d,
# using the Cholesky decomposition A = t(Qt) %*% Qt.
#
# A      positive-definite 2 x 2 matrix (e.g. an inverse covariance)
# mu     centre
# d      one or more radii; ellipses are drawn from the largest to the
#        smallest so that, with solid=TRUE, smaller ones are painted on top
# col    colour(s), recycled to length(d) and kept paired with d
# len    number of points along each half of the ellipse
# add    add to the current plot (otherwise the first ellipse starts a plot)
# solid  draw filled polygons (red border) instead of lines
# ...    passed to polygon() / lines() / plot()

plot.ellipse = function(A, mu=0, d=3, add=FALSE, solid=FALSE, len=100,
           col="red", lty=2, ...) {
  if(any(d <= 0)) return()
  Qt = chol(A)
  ## `decreasing` must be spelled out: arguments after `...` in order() are
  ## not partially matched, so `dec=TRUE` became an extra sort key (and an
  ## error for vector d).  d is reordered together with col.
  o = order(d, decreasing=TRUE)
  d = d[o]
  col = rep(col, length.out=length(d))[o]
  for(i in seq_along(d))  {
    ## points y on the circle |y| = d[i]
    y0 = seq(-d[i], d[i], length.out=len)[-c(1,len)]
    y1 = c(-d[i], y0, d[i], rev(y0), -d[i])
    yd = sqrt(d[i]^2 - y0^2)
    y2 = c(0, -yd, 0, rev(yd), 0)
    ## x - mu = Qt^{-1} y, computed row-wise as t(Qt^{-1} %*% t(Y))
    x = sweep(t(backsolve(Qt, rbind(y1, y2))), 2, mu, "+")
    if(solid) polygon(x[,1], x[,2], col=col[i], border="red", lty=1, ...)
    else {
      if(add || i != 1) lines(x[,1], x[,2], col=col[i], lty=lty, ...)
      else plot(x[,1], x[,2], type="l", col=col[i], lty=lty, ...)
    }
  }
}



