message("Beginning simulations...")

# Start from an empty global environment: the foreach call later in this
# script exports everything returned by ls() to the workers, so stray objects
# from an interactive session would otherwise be shipped along.
rm(list = ls())

# Attach every package the script depends on (library() errors if one is
# missing). lapply() is used only for its side effect; its result is discarded.
pks <- c("bizicount", "doRNG", "parallel", "doParallel", "RhpcBLASctl")
invisible(lapply(pks, library, character.only = TRUE))

# Simulate one data set from a bivariate zero-inflated count DGP whose two
# margins are linked through a Gaussian copula.
#
# Arguments:
#   n          number of observations.
#   b1, b2     coefficient vectors (intercept first) for margins 1 and 2. Each
#              may be a scalar, in which case that margin is intercept-only.
#   psi1, psi2 zero-inflation parameters passed to qzip() for each margin.
#   dep        correlation of the latent bivariate normal used by the copula.
#
# Margin-1 covariates are Bernoulli(0.5) draws and margin-2 covariates are
# Exponential(rate = 3) draws. Random numbers are drawn in the same order as
# before (rbinom, rexp, mvrnorm), so seeded runs reproduce earlier results.
#
# Returns a data.frame with the non-intercept covariates (X1*, X2*), the
# outcomes y1 and y2, the rate parameters lam1 = exp(X1 %*% b1) and
# lam2 = exp(X2 %*% b2), and psi1 and psi2.
gen <- function(n,
                b1,
                b2,
                psi1,
                psi2,
                dep) {
  k1 <- length(b1)
  k2 <- length(b2)

  # nrow = n keeps an n x 0 slope matrix (instead of 0 x 0) when a margin has
  # no covariates. The intercept column therefore still gets n rows, which
  # makes scalar b1 / b2 work.
  X1 <- cbind(1, matrix(rbinom(n * (k1 - 1), 1, .5), nrow = n, ncol = k1 - 1))
  X2 <- cbind(1, matrix(rexp(n * (k2 - 1), 3), nrow = n, ncol = k2 - 1))
  lam1 <- exp(X1 %*% b1)
  lam2 <- exp(X2 %*% b2)

  # Gaussian copula: correlated normals -> uniforms -> marginal quantiles.
  norm_vars <- MASS::mvrnorm(n, mu = c(0, 0),
                             Sigma = matrix(c(1, dep, dep, 1), ncol = 2))
  # mvrnorm() returns a plain vector when n == 1; restore the n x 2 shape so
  # the column indexing below always works.
  U <- pnorm(matrix(norm_vars, ncol = 2))
  y1 <- qzip(U[, 1], lam1, psi1)
  y2 <- qzip(U[, 2], lam2, psi2)

  data.frame(X1 = X1[, -1], X2 = X2[, -1], y1, y2, lam1, lam2, psi1, psi2)
}

# Simulate one data set with gen() and fit four competing models to it:
#   biv1 - bizicount() with zip/zip margins and a Gaussian copula ("Correct model"),
#   biv2 - bizicount() with pois/pois margins and a Gaussian copula
#          (zero inflation omitted),
#   uni1 - univariate zic.reg() fit of margin 1,
#   uni2 - univariate zic.reg() fit of margin 2.
#
# All arguments (...) are forwarded unchanged to gen(). They should be passed
# by name: n, psi1, psi2 and dep are looked up by name and copied into the
# result.
#
# Returns NULL if either outcome has no zeros or if any of the four fits
# raises an error. Otherwise returns a named numeric vector with:
#   - the true settings,
#   - every model's coefficients (`$coef`) and standard errors (`$se`),
#   - the `$conv` element of both bivariate fits,
#   - the logLik() of all four models.
estimate <- function(...) {
  settings <- list(...)
  data <- gen(...)

  # must have zeros in data, otherwise zi parameter is -Inf theoretically
  if (min(data$y1) > 0 || min(data$y2) > 0) {
    return(NULL)
  }

  # f1/f2 carry a `| 1` part (presumably an intercept-only zero-inflation
  # component); o1/o2 have none and are fitted with Poisson margins. The
  # covariate names assume gen() produced two covariates per margin.
  f1 <- y1 ~ X1.1 + X1.2 | 1
  f2 <- y2 ~ X2.1 + X2.2 | 1
  o1 <- y1 ~ X1.1 + X1.2
  o2 <- y2 ~ X2.1 + X2.2

  # Evaluate a model-fitting expression and map any error to NULL. `expr` is a
  # lazily evaluated argument, so each fit runs only when it is reached below.
  fit_or_null <- function(expr) tryCatch(expr, error = function(e) NULL)

  # Correct model
  biv1 <- fit_or_null(
    bizicount(f1, f2, data = data, iterlim = 500,
              margins = c("zip", "zip"), cop = "gaus", stepmax = 25)
  )
  if (is.null(biv1)) return(NULL)

  # Omit zero inflation in biv mod
  biv2 <- fit_or_null(
    bizicount(o1, o2, data = data, iterlim = 500,
              margins = c("pois", "pois"), keep = TRUE, cop = "gaus",
              stepmax = 25)
  )
  if (is.null(biv2)) return(NULL)

  # Model as univariate, margin 1
  uni1 <- fit_or_null(zic.reg(f1, data = data, iterlim = 500))
  if (is.null(uni1)) return(NULL)

  # Model as univariate, margin 2
  uni2 <- fit_or_null(zic.reg(f2, data = data, iterlim = 500))
  if (is.null(uni2)) return(NULL)

  fits <- list(biv1 = biv1, biv2 = biv2, uni1 = uni1, uni2 = uni2)
  # `$` (rather than [[) is used deliberately so that R's partial name
  # matching on the fitted objects applies.
  estimates <- unlist(lapply(fits, function(fit) fit$coef))
  std_errs <- unlist(lapply(fits, function(fit) fit$se))

  c(
    n = settings$n,
    psi1_tru = settings$psi1,
    psi2_tru = settings$psi2,
    dep_true = settings$dep,
    est = estimates,
    se = std_errs,
    conv1 = biv1$conv,
    conv2 = biv2$conv,
    ll.biv1 = logLik(biv1),
    ll.biv2 = logLik(biv2),
    ll.uni1 = logLik(uni1),
    ll.uni2 = logLik(uni2)
  )
}

# -- Data generating processes, monte carlo grid ----
# Design settings: sample size, copula dependence, margin coefficients
# (intercept first) and zero-inflation levels. Every combination of n, psi1,
# psi2 and dep is one design point, and each point is replicated nsims times.
n <- 500
dep <- c(.15, .85)

b1 <- c(1, 3.25, -2.3)
b2 <- c(2, -1.75, 3.5)

psi1 <- psi2 <- c(.1, .6)

nsims <- 500

# One row per simulation run: the full factorial design stacked nsims times.
# local() keeps the un-replicated design from becoming a global variable.
grid <- local({
  design <- expand.grid(n = n, psi1 = psi1, psi2 = psi2, dep = dep)
  do.call(rbind.data.frame, replicate(nsims, design, simplify = FALSE))
})

# -- Setup parallel backend, disable matrix multi threading ----
# The seed is set once, before the cluster is created. The %dorng% loop derives
# its per-iteration RNG streams from the RNG state current when it runs, so
# keeping set.seed() here is part of reproducing earlier results.
set.seed(987465)

# detectCores() may return NA (core count unknown), and detectCores() - 1 is 0
# on a single-core machine. Either would make makeCluster() fail, so always
# request at least one worker.
ncores <- max(1L, parallel::detectCores() - 1L, na.rm = TRUE)

# FORK clusters are only available on Unix-alikes; fall back to socket
# (PSOCK) workers elsewhere.
cl <- makeCluster(
  ncores,
  type = if (grepl("unix", .Platform$OS.type)) "FORK" else "PSOCK"
)
registerDoParallel(cl)

# NOTE(review): with a cluster object registered, doParallel appears to use
# its snow-style backend, which likely ignores .options.multicore. Kept
# because the foreach call passes it.
mcoptions <- list(preschedule = TRUE)

# Single-threaded BLAS/OpenMP in the master process, so that parallelism
# comes only from the cluster workers.
RhpcBLASctl::blas_set_num_threads(1)
RhpcBLASctl::omp_set_num_threads(1)

# -- Simulate ----
# One foreach iteration per grid row. %dorng% gives each iteration its own
# reproducible RNG stream. With .errorhandling = "pass", an iteration that
# errors yields the condition object in `res`; estimate() itself returns
# NULL for discarded replications.
start <- Sys.time()
res <- foreach(
  i = seq_len(nrow(grid)),
  .packages = c("bizicount", "RhpcBLASctl"),
  .export = ls(),
  .errorhandling = "pass",
  .verbose = FALSE,
  .options.multicore = mcoptions
) %dorng% {
  # Pin BLAS/OpenMP to one thread inside each worker as well, so nested
  # threading does not oversubscribe the cores used by the cluster.
  RhpcBLASctl::blas_set_num_threads(1)
  RhpcBLASctl::omp_set_num_threads(1)

  settings <- grid[i, ]

  #cat("Starting on Grid Row ", i, "\n")
  estimate(
    n = settings$n,
    b1 = b1,
    b2 = b2,
    psi1 = settings$psi1,
    psi2 = settings$psi2,
    dep = settings$dep
  )
}

end <- Sys.time()

cat("Total time:\n")
end - start
# -- Cleanup, save output ----
# Shut down the explicit worker cluster. stopImplicitCluster() additionally
# stops any implicit cluster doParallel may have created on its own.
stopCluster(cl)
stopImplicitCluster()

# Store sessionInfo() next to the results so the run's package versions and
# platform are recorded with the output.
session <- sessionInfo()
save(res, session, file = "output_montes_small.RData")

message("\n\n---Monte carlo simulations complete. Data available at ./output_montes_small.RData---\n\n")
