# supervised group lasso
library(methods)
library(Matrix)

#' Supervised group lasso (two-stage variable selection)
#'
#' The procedure runs in three stages:
#' \enumerate{
#'   \item The variables (columns of \code{X}) are clustered hierarchically and
#'     the tree is cut at the largest jump between consecutive merge heights.
#'   \item Inside every cluster with more than one variable, a cross-validated
#'     lasso (\code{glmnet::cv.glmnet}) keeps the variables whose coefficient
#'     is non-zero at \code{s}; at least one variable is always kept.
#'   \item A cross-validated group lasso (\code{gglasso::cv.gglasso}) is fitted
#'     on the retained variables, using the clusters as groups.
#' }
#' Both regression stages are fitted without intercept on the column-wise
#' standardised design and the centred response.
#'
#' @param X Numeric matrix-like object, observations in rows, at least 3
#'   columns.
#' @param y Numeric response with one value per row of \code{X}.
#' @param hc Optional precomputed hierarchical clustering of the columns of
#'   \code{X} (an object usable by \code{cutree} with a \code{height}
#'   component). When \code{NULL}, it is computed with
#'   \code{fastcluster::hclust}.
#' @param dist Distance used when \code{hc} is \code{NULL}: \code{"euclidian"}
#'   (the historical spelling) or \code{"euclidean"} for the Euclidean distance
#'   between columns; any other value selects \code{1 - abs(cor(X))}.
#' @param method Agglomeration method forwarded to \code{fastcluster::hclust}.
#' @param K Number of cross-validation folds used in both regression stages.
#' @param s Penalty at which coefficients are extracted: \code{"lambda.1se"}
#'   (default) or \code{"lambda.min"}. A non-character value is forwarded
#'   unchanged to \code{coef}.
#' @param plot Currently unused; kept for interface compatibility.
#'
#' @return A list with components
#'   \item{sgl}{the \code{cv.gglasso} fit,}
#'   \item{sel}{indices (columns of \code{X}) with a non-zero group-lasso
#'     coefficient at \code{s},}
#'   \item{group}{cluster label of every column of \code{X},}
#'   \item{groupGL}{cluster label of every variable kept after the lasso stage,}
#'   \item{varGL}{column indices of the variables kept after the lasso stage,}
#'   \item{temps}{elapsed times in seconds: clustering (\code{NA} when
#'     \code{hc} is supplied), per-cluster lasso, a single group-lasso path,
#'     cross-validated group lasso.}
SGLasso2 <- function(X, y, hc = NULL, dist = "euclidian", method = "ward.D2",
                     K = 10, s = c("lambda.1se", "lambda.min"), plot = FALSE)
{
  # ---- argument checks (done before any package is touched) ----------------
  n <- nrow(X)
  p <- ncol(X)
  if (is.null(n) || is.null(p)) {
    stop("'X' must be a matrix-like object with rows and columns.",
         call. = FALSE)
  }
  if (p < 3L) {
    # The cut rule below needs at least two merge-height differences.
    stop("'X' must have at least 3 columns to choose the number of groups.",
         call. = FALSE)
  }
  if (length(y) != n) {
    stop("length of 'y' (", length(y), ") must equal nrow(X) (", n, ").",
         call. = FALSE)
  }
  if (!is.null(hc) && length(hc$height) != p - 1L) {
    stop("'hc' does not describe a clustering of the ", p, " columns of 'X'.",
         call. = FALSE)
  }
  # Resolve the default choice vector; numeric penalties pass through as is.
  if (is.character(s)) {
    s <- match.arg(s)
  }

  # ---- dependencies: fail loudly instead of require()'s silent FALSE -------
  needed <- c(if (is.null(hc)) "fastcluster", "glmnet", "gglasso")
  available <- vapply(needed, requireNamespace, logical(1), quietly = TRUE)
  if (!all(available)) {
    stop("missing required package(s): ",
         paste(needed[!available], collapse = ", "), call. = FALSE)
  }

  # ---- centring / standardisation ------------------------------------------
  yb <- y - mean(y)
  Xb <- scale(X, center = TRUE, scale = TRUE)

  temps <- rep(NA, 4)

  # ---- stage 1: hierarchical clustering of the variables -------------------
  if (is.null(hc)) {
    t_start <- proc.time()

    if (dist %in% c("euclidian", "euclidean")) {
      # NOTE(review): the distance is computed on the raw columns of X, as the
      # original implementation did (it built a standardised copy but never
      # used it). Confirm whether standardised columns were intended.
      D <- stats::dist(t(X))
    } else {
      # correlation-based dissimilarity
      D <- stats::as.dist(1 - abs(stats::cor(X)))
    }

    hc <- fastcluster::hclust(D, method = method)
    temps[1] <- (proc.time() - t_start)[3]
  }

  # Cut the tree just before the largest jump in merge height.
  nbGroup <- p - which.max(diff(hc$height))
  group <- stats::cutree(hc, nbGroup)

  # ---- stage 2: lasso inside every cluster ---------------------------------
  t_start <- proc.time()

  # Balanced fold labels, randomly assigned to the observations.
  foldcv <- rep(1:K, ceiling(n / K))[1:n]
  foldcv <- foldcv[sample(n)]

  kept <- vector("list", nbGroup)
  for (i in seq_len(nbGroup)) {
    members <- which(group == i)

    if (length(members) == 1) {
      # Singleton cluster: nothing to select.
      kept[[i]] <- members
      next
    }

    fit <- glmnet::cv.glmnet(Xb[, members], yb, foldid = foldcv,
                             intercept = FALSE)
    keep <- which(stats::coef(fit, s = s)[-1] != 0)

    # Always keep at least one variable per cluster.
    if (length(keep) == 0) {
      first_active <- which(fit$glmnet.fit$df != 0)[1]
      if (is.na(first_active)) {
        # No variable enters anywhere on the path: keep the one with the
        # largest absolute inner product with the response, i.e. the first
        # the lasso would activate.
        keep <- which.max(abs(crossprod(Xb[, members, drop = FALSE], yb)))
      } else {
        # Variables active at the first non-empty model of the path.
        keep <- which(fit$glmnet.fit$beta[, first_active] != 0)
      }
    }

    kept[[i]] <- members[keep]
  }

  groupGL <- rep(seq_len(nbGroup), lengths(kept))
  varGL <- unlist(kept)

  temps[2] <- (proc.time() - t_start)[3]

  # ---- stage 3: group lasso on the retained variables ----------------------
  # The centred response is used because the model has no intercept; with the
  # raw response, held-out predictions in cross-validation would be biased by
  # mean(y).
  Xsel <- Xb[, varGL, drop = FALSE]

  t1 <- proc.time()
  # Single regularisation path, fitted only to report its running time.
  gglasso::gglasso(Xsel, yb, group = groupGL, intercept = FALSE)
  t2 <- proc.time()
  res <- gglasso::cv.gglasso(Xsel, yb, group = groupGL, nfolds = K,
                             intercept = FALSE)
  t3 <- proc.time()

  beta <- stats::coef(res, s = s)[-1]

  temps[3] <- (t2 - t1)[3]
  temps[4] <- (t3 - t2)[3]

  list(sgl = res, sel = varGL[which(beta != 0)], group = group,
       groupGL = groupGL, varGL = varGL, temps = temps)
}
