library("BayesX")
library("sp")
library("survival")
library("colorspace")
library("gridExtra")
library("parallel")
## Global settings used throughout the script.
## `ncores` is not read in this file; presumably it is used by the timing
## script sourced in the appendix -- confirm.
options(ncores = 8)
## All data files are read from ./data below the working directory.
datapath <- file.path(getwd(), "data")
## Parallel-safe RNG (the generator parallel::mclapply uses for its streams).
RNGkind("L'Ecuyer-CMRG")

library("tramME")
## Number of OpenMP threads for the compiled TMB code of tramME.
TMB::openmp(n = 4, DLL = "tramME")
## Optimizer settings shared by every tramME fit below.
opt <- optim_control(
  method = "nlminb",
  iter.max = 300,
  eval.max = 400,
  rel.tol = sqrt(.Machine$double.eps)
)

## ========== Section 4.1: Child mortality in Nigeria

## Map of Nigerian districts (BayesX boundary file) and the child mortality
## data.
nmap <- read.bnd(file.path(datapath, "nigeria.bnd"))
nm <- bnd2sp(nmap)
nigeria <- read.table(file.path(datapath, "nigeria_cox.raw"),
                      header = TRUE, sep = " ")

## Neighbourhood structure of the districts. bnd2gra() returns the map's
## graph matrix; zeroing its diagonal and flipping the sign of the rest leaves
## 1 for neighbouring districts (assumes neighbours are coded -1 off the
## diagonal, as in BayesX "gra" objects -- confirm).
adj.mat <- bnd2gra(nmap)
adj.mat <- diag(diag(adj.mat)) - as.matrix(adj.mat)
## One vector of neighbour indices per district, named like the rows of the
## matrix. lapply() always yields a list, whereas apply() would simplify to a
## matrix if every district had the same number of neighbours; the mrf smooth
## (xt = list(nb = nb)) needs a list.
nb <- setNames(
  lapply(seq_len(nrow(adj.mat)), function(i) which(adj.mat[i, ] == 1)),
  rownames(adj.mat)
)

## Covariates as labelled factors.
nigeria$district <- factor(nigeria$district)
nigeria$breastfeed <- factor(nigeria$breastfeedduration == 1,
                             levels = c(FALSE, TRUE),
                             labels = c("no", "yes"))
nigeria$education <- factor(nigeria$education,
                            levels = c(0, 1),
                            labels = c("primary+", "no"))
nigeria$delivery <- factor(nigeria$placeofdelivery,
                           levels = c(0, 1),
                           labels = c("home", "hospital"))
nigeria$sex <- factor(nigeria$sex, levels = c(0, 1),
                      labels = c("female", "male"))
## Interval bounds and truncation times as plain numerics for Resp() below.
nigeria$heapingleft <- as.numeric(nigeria$heapingleft)
nigeria$heapingright <- as.numeric(nigeria$heapingright)
nigeria$trunctime <- as.numeric(nigeria$trunctime)

## ---- Figure 2
## Frequencies of the observed times with censind == 1 (presumably the
## observed deaths), binned by month. Assuming survtime is recorded in days,
## findInterval() against m * 30 maps each time to its 30-day bin, labelled
## 0, ..., 54; times beyond 54 * 30 fall into the last bin.
m <- 0:54
par(mar = c(4, 4, 1, 1), las = 1, cex = 0.9)
tbl <- table(m[findInterval(nigeria$survtime[nigeria$censind == 1], m*30)])
## Empty canvas with a horizontal grid only; the counts are drawn on top as
## thick grey vertical bars (type = "h", flat line ends via lend = 1).
plot(0, type = "n", xlim = c(0, 54), ylim = c(0, 300), xaxt = "n",
     panel.first = grid(nx = NA, ny = NULL),
     xlab = "time (months)", ylab = "frequency")
lines(tbl, lwd = 8, col = "grey", type = "h", lend = 1)
## Half-yearly tick marks.
axis(1, at = seq(0, 54, by = 6), labels = seq(0, 54, by = 6))

## Survival response on [0, Inf):
##  * left truncation at trunctime when it is positive (NA = not truncated),
##  * interval censoring between heapingleft and heapingright,
##  * when the interval collapses to a single point but censindinterval < 1
##    (presumably a censored record), the right end is opened to Inf.
nigeria$y <- with(nigeria, {
  open_right <- heapingleft == heapingright & censindinterval < 1
  Resp(tleft = ifelse(trunctime > 0, trunctime, NA),
       cleft = heapingleft,
       cright = ifelse(open_right, Inf, heapingright),
       bounds = c(0, Inf))
})
## A few example records of the response.
nigeria$y[c(1, 2, 3, 13, 14, 21, 22)]

## Cox-type model containing only a Markov random field (spatial) smooth of
## district, built on the neighbourhood list `nb`.
m_sp <- CoxphME(
  y ~ s(district, bs = "mrf", k = 20, xt = list(nb = nb)),
  data = nigeria,
  control = opt,
  bounds = c(0, Inf),
  support = c(1, 1400),
  log_first = TRUE,
  order = 8
)
summary(m_sp)

## Estimated district effects (first and only smooth term).
sm_sp <- smooth_terms(m_sp)
head(sm_sp[[1]])

## ---- Figure 3
## Left map: estimated spatial (district) effect from m_sp.
## Right map: sign of the pointwise 95% interval per district
##   1 = interval entirely above 0, -1 = entirely below 0, 0 = covers 0.
## Assumes columns 2 and 3 of the smooth_terms() output hold the estimate and
## its standard error, in the same district order as the polygons in `nm`
## -- TODO confirm.
## %o% binds tighter than * and +, so ci is an (n x 2) matrix with columns
## estimate - 1.96 * SE and estimate + 1.96 * SE.
ci <- sm_sp[[1]][, 2] + qnorm(0.975) * sm_sp[[1]][, 3] %o% c(-1, 1)
sig <- apply(ci > 0, 1, function(x) if (all(x)) 1 else if (any(x)) 0 else -1)

cols <- colorspace::diverging_hcl(128, palette = "BlueRed3")
am <- SpatialPolygonsDataFrame(nm, data = data.frame(sm = sm_sp[[1]][, 2]))
p1 <- spplot(am, "sm",
             col.regions = cols,
             cuts = 127,
             ## at = seq(-0.4, 0.4, length.out = 127),
             colorkey = list(space = "bottom", height = 0.5),
             par.settings = list(axis.line = list(col = 'transparent'),
                                 fontsize=list(text = 10)))
## Three classes (-1, 0, 1), coloured with one shade from each end of the same
## diverging palette and white for "not significant".
am <- SpatialPolygonsDataFrame(nm, data = data.frame(sig = sig))
p2 <- spplot(am, "sig",
             col.regions = c(cols[20], "white", cols[108]),
             cuts = 2,
             colorkey = NULL,
             par.settings = list(axis.line = list(col = 'transparent'),
                                 fontsize=list(text = 10)))
gridExtra::grid.arrange(p1, p2, ncol = 2)

## Extend the spatial model with the child/mother covariates and penalized
## spline (P-spline) effects of age at birth and birth order.
m_ph <- update(
  m_sp,
  . ~ . + breastfeed + sex + education + delivery +
    s(ageatbirth, bs = "ps") + s(birthorder, bs = "ps")
)
summary(m_ph)



## ---- Figure 4
## Smooth effects number 2 and 3 of m_ph; given the formula order this should
## be ageatbirth and birthorder (element 1 being the district effect)
## -- TODO confirm that smooth_terms() follows formula order.
par(mar = c(4, 4, 1, 1), las = 1, cex = 0.9)
sm_ph <- smooth_terms(m_ph)
plot(sm_ph[2:3], panel.first = grid(), ylim = c(-2, 2))

## Residuals of m_ph (plotted as martingale residuals in Figure 5).
res <- residuals(m_ph)

## The observation with the most negative residual, as a one-row model frame.
(outl <- model.frame(m_ph)[which.min(res), ])

## Its predicted survivor probability at time 1506 ...
predict(m_ph, type = "survivor", newdata = outl, q = 1506)

## ... and its predicted 10%, 50% and 90% quantiles.
predict(m_ph, type = "quantile", newdata = outl, p = c(0.1, 0.5, 0.9))

## Split this observation's linear predictor into per-variable contributions:
## bind the design parts mm$Yr, mm$X and t(mm$Zt) into one row, multiply it
## elementwise by the complete coefficient vector (assumes the column order of
## Xm matches the order of b -- confirm), group the products by the first
## variable name in each coefficient label and sum within groups.
mm <- model.matrix(m_ph, data = outl)
b <- coef(m_ph, complete = TRUE)
Xm <- cbind(mm$Yr, mm$X, t(as.matrix(mm$Zt)))
vns <- sapply(names(b), function(n) all.vars(str2lang(n))[1])
lch <- sapply(split(Xm*b, vns), sum)
## Exponentiated contributions (presumably multiplicative factors on the
## cumulative hazard scale, given the name `lch`).
exp(lch)

## ---- Figure 5
## Left: martingale residuals of m_ph in observation order, with the most
## negative one circled in red.
## Right: the same residuals against household size (jittered strip chart),
## with the mean residual per household size as red diamonds.
par(mfrow = c(1, 2), mar = c(4, 4, 1, 1), las = 1, cex = 0.9)
plot(res, pch = 20, col = grey(0.5, 0.3), ylab = "Martingale residuals",
     panel.first = grid())
## Position of the outlier. It gets its own name so `outl` keeps holding the
## outlier's model-frame row used by the predictions above.
i_outl <- which.min(res)
points(i_outl, res[i_outl], pch = 1, cex = 2, col = 2, lwd = 3)

## Household size for each observation used in m_ph. Assumes the row names of
## the model data are row positions in `nigeria` (read.table's default row
## names).
si <- nigeria$householdsize[as.numeric(rownames(m_ph$data))]
plot(0, type = "n",
     xlim = range(si), ylim = range(res),
     xlab = "Household size", ylab = "Martingale residuals",
     panel.first = grid())
## Named res_mean (not `rm`) so base::rm() is not shadowed.
res_mean <- tapply(res, si, mean)
stripchart(res ~ si, method = "jitter", pch = 20,
           at = as.numeric(names(res_mean)),
           col = grey(0.5, 0.3), vertical = TRUE,
           add = TRUE)
points(as.numeric(names(res_mean)), res_mean, col = 2, pch = 18, cex = 1.5)

## Same covariates as m_ph, but education moves to the left of `|`. In tram
## response syntax this lets the baseline transformation differ by education,
## so its effect is no longer constrained to be proportional over time.
m_tv <- CoxphME(
  y | education ~ breastfeed + sex + delivery +
    s(district, bs = "mrf", k = 20, xt = list(nb = nb)) +
    s(ageatbirth, bs = "ps") + s(birthorder, bs = "ps"),
  data = nigeria,
  control = opt,
  bounds = c(0, Inf),
  support = c(1, 1400),
  log_first = TRUE,
  order = 8
)
summary(m_tv)



## ---- Figure 6
## Education-specific baseline of m_tv: confidence bands for the baseline
## transformation h(y | education) (left) and the survivor function (right).
par(mfrow = c(1, 2), mar = c(4, 4, 1, 1), las = 1)
## Two copies of row 9 of the model frame, differing only in education:
## row 1 "no", row 2 "primary+". The level order matches the factor set up
## for the data.
nd <- model.frame(m_tv)[rep(9, 2), ]
nd$education <- factor(c("no", "primary+"), levels = c("primary+", "no"),
  labels = c("primary+", "no"))
## Bands on 100 time points in [1, 1000]. baseline_only = TRUE presumably
## restricts the evaluation to the baseline part -- confirm with ?confband.
cb_lch <- confband(m_tv, newdata = nd,
  q = seq(1, 1000, length.out = 100), baseline_only = TRUE)
cb_sur <- confband(m_tv, newdata = nd, type = "survivor",
  q = seq(1, 1000, length.out = 100), baseline_only = TRUE)
plot(cb_lch, single_plot = TRUE, lty = 1, lwd = 2,
  panel.first = grid(), xlim = c(0, 1000), ylim = c(-2, 1),
  ylab = "h(y | education)")
plot(cb_sur, single_plot = TRUE, lty = 1, lwd = 2,
  panel.first = grid(), xlim = c(0, 1000), ylim = c(0, 1),
  ylab = "Survivor")
## Assumes the bands are drawn in row order of nd (colour 1 = "no",
## colour 2 = "primary+").
legend("topright", c("No education", "At least primary education"),
       col = c(1, 2), lwd = 2, lty = 1, bty = "n")

## ========== Section 4.2: Burn victim recovery

## Burn victim quality-of-life data: numeric outcomes/covariates and factor
## identifiers (Study coded 1..10, patients re-coded as numeric factor).
burn <- read.csv(file.path(datapath, "burn_data.csv"))
burn <- within(burn, {
  EQindex <- as.numeric(EQindex)
  EQvas <- as.numeric(EQvas)
  Age <- as.numeric(Age)
  TBSA <- as.numeric(TBSA)
  LOS <- as.numeric(LOS)
  Study <- factor(Study, levels = 1:10)
  Patient <- factor(as.numeric(Patient))
  Gender <- factor(Gender)
  EQ_SF <- factor(EQ_SF)
})

## ---- Figure 7
## Empirical distribution function of EQindex, drawn as a step function with
## vertical segments and no point markers.
par(cex = 0.8, las = 1, mar = c(4, 4, 1, 1))
plot(ecdf(burn$EQindex), verticals = TRUE, pch = NA, lwd = 2,
     panel.first = grid(), main = NULL,
     ylab = "Probability", xlab = "EQindex")

## EQindex as an interval-censored response: R(..., as.R.interval = TRUE)
## turns the observed values into intervals (see ?mlt::R), which are then
## stored as a Surv object.
burn$EQi <- as.Surv(with(burn, R(EQindex, as.R.interval = TRUE)))
head(burn$EQi)

## ColrME (mixed-effects Colr, i.e. proportional odds) model. It has smooth
## effects of time, TBSA, LOS and Age, a Gender effect, a random intercept per
## patient within study and a correlated random intercept and slope in time
## per study. The outcome is bounded above by 1.
m_po <- ColrME(
  EQi ~ s(time) + s(TBSA) + s(LOS) + s(Age) + Gender +
    (1 | Study:Patient) + (time | Study),
  data = burn,
  order = 8,
  bounds = c(-Inf, 1),
  support = c(-0.1, 1),
  control = opt
)
summary(m_po)



## ---- Figure 8
## All estimated smooth terms of m_po (time, TBSA, LOS and Age).
par(mar = c(4, 4, 1, 1), las = 1, cex = 0.9)
plot(smooth_terms(m_po), panel.first = grid())

## Marginal (population-averaged) predictions from a tramME model, obtained by
## Monte Carlo integration over the random effects.
##
## Arguments:
##   obj          fitted model supporting simulate(type = "ranef") and
##                predict(ranef = ) (used here with LmME and ColrME fits).
##   newdata      data.frame with the covariate values to predict for.
##   scale        prediction scale, passed to predict() as `type`.
##   ndraws       number of random-effect draws. With antithetic = TRUE it is
##                rounded down to an even number (floor(ndraws / 2) pairs).
##   antithetic   if TRUE, every draw r is paired with -r (antithetic
##                variates).
##   return_draws if TRUE, return the unaveraged list of per-draw predictions.
##   ncpus        cores for parallel::mclapply(); 1 runs sequentially.
##   ...          passed to both simulate() and predict() (e.g. q, seed).
##
## Value: a matrix of averaged predictions. Columns correspond to the rows of
## newdata; row and column names are taken from the first predict() result.
## Attribute "mc.se" holds the Monte Carlo standard errors of the averages.
mpredict <- function(obj, newdata,
  scale = c("distribution", "density", "survivor"),
  ndraws = 500, antithetic = FALSE, return_draws = FALSE,
  ncpus = 1L, ...) {
  scale <- match.arg(scale)
  if (antithetic) ndraws <- floor(ndraws / 2)
  re <- simulate(obj, type = "ranef", newdata = newdata, nsim = ndraws, ...)
  ## Antithetic draws are appended, so draws i and i + ndraws/2 form a pair.
  if (antithetic) re <- c(re, lapply(re, `-`))
  ndraws <- length(re)
  FUN <- function(r) {
    predict(obj, newdata = newdata, type = scale, ranef = r, ...)
  }
  if (ncpus > 1) {
    pr <- parallel::mclapply(re, FUN, mc.cores = ncpus)
  } else {
    pr <- lapply(re, FUN)
  }
  if (return_draws) return(pr)
  rns <- rownames(pr[[1]])
  cns <- colnames(pr[[1]])
  pr <- unlist(pr)
  ## d = number of evaluation points per observation (e.g. length of q).
  d <- length(pr) / (nrow(newdata) * ndraws)
  pr <- array(pr, dim = c(d, nrow(newdata), ndraws))
  mp <- apply(pr, c(1, 2), mean)
  if (antithetic) {
    ## Draws within a pair are dependent, so the MC variance of the mean is
    ## Var(pair mean) / (number of pairs), i.e. (s^2 + c) / ndraws with s^2
    ## the per-draw variance and c the within-pair covariance. Estimating it
    ## from the pair means directly keeps both terms at their full weight.
    npairs <- ndraws / 2
    mcv <- apply(pr, c(1, 2), function(x) {
      var((x[seq_len(npairs)] + x[npairs + seq_len(npairs)]) / 2) / npairs
    })
  } else {
    mcv <- apply(pr, c(1, 2), var) / ndraws
  }
  rownames(mp) <- rownames(mcv) <- rns
  colnames(mp) <- colnames(mcv) <- cns
  attr(mp, "mc.se") <- sqrt(mcv)
  mp
}

## Normal linear mixed model with the same predictors and random effects as
## m_po; it is the comparison model in Figure 9.
m_lm <- LmME(
  EQindex ~ s(time) + s(TBSA) + s(LOS) + s(Age) + Gender +
    (1 | Study:Patient) + (time | Study),
  data = burn,
  control = opt
)

## Reference patient: the first row of the model frame.
model.frame(m_lm)[1, ]

## The reference patient on a grid of 100 time points in [0, 36] ...
nd <- model.frame(m_lm)[rep(1, 100), ]
nd$time <- seq(0, 36, length.out = 100)
## ... and the same patient at time 12 only.
nd2 <- model.frame(m_lm)[1, ]
nd2$time <- 12

## Marginal P(EQindex > 0.8) and P(EQindex > 0.9) along the time grid
## (500 antithetic draws; `seed` is forwarded to simulate()).
pr_lm <- mpredict(m_lm, nd, scale = "survivor", q = c(0.8, 0.9),
                  ndraws = 500, antithetic = TRUE, seed = 100)
pr_po <- mpredict(m_po, nd, scale = "survivor", q = c(0.8, 0.9),
                  ndraws = 500, antithetic = TRUE, seed = 100)

## Marginal CDF at time 12 on 100 points in [-0.1, 1].
cdf_lm <- mpredict(m_lm, nd2, scale = "distribution",
                   q = seq(-0.1, 1, length.out = 100),
                   ndraws = 500, antithetic = TRUE, seed = 100)
cdf_po <- mpredict(m_po, nd2, scale = "distribution",
                   q = seq(-0.1, 1, length.out = 100),
                   ndraws = 500, antithetic = TRUE, seed = 100)

## ---- Figure 9
## Left: marginal P(EQindex > 0.8) (solid) and P(EQindex > 0.9) (dashed) over
## time for the reference patient, LmME in black and ColrME in red. The grey
## line marks month 12.
## Right: marginal CDF of EQindex at month 12 for both models. Grey lines mark
## the thresholds 0.8 and 0.9.
par(mfrow = c(1, 2), mar = c(4, 4, 1, 1), las = 1, cex = 0.9)
matplot(nd$time, t(pr_lm), col = 1, lwd = 2, type = "l", lty = c(1, 2),
        ylim = c(0, 1), ylab = "Probability", xlab = "Time (months)",
        panel.first = grid())
matlines(nd$time, t(pr_po), col = 2, lwd = 2, type = "l", lty = c(1, 2))
## Legend labels get their own name so the SpatialPolygons object `nm` (the
## map used in Figure 3) is not overwritten.
leg_lab <- paste(rep(c("LmME, ", "ColrME,"), each = 2),
                 "EQindex >", rep(c(0.8, 0.9), 2))
legend("topright", leg_lab, col = c(1, 1, 2, 2), lty = c(1, 2, 1, 2),
       lwd = 2, bty = "n", cex = 0.9)
abline(v = 12, lwd = 3, col = grey(0.5, 0.5))

## Row names of the mpredict() result are the evaluation points q.
matplot(as.numeric(rownames(cdf_lm)), cdf_lm, type = "l",
        col = 1, lty = 1, lwd = 2, ylim = c(0, 1), ylab = "Probability",
        xlab = "EQindex", panel.first = grid())
matlines(as.numeric(rownames(cdf_po)), cdf_po,
         col = 2, lty = 1, lwd = 2)
legend("topleft", c("LmME", "ColrME"), col = c(1, 2), lty = 1,
       lwd = 2, cex = 0.9, bty = "n")
abline(v = c(0.8, 0.9), lwd = 3, col = grey(0.5, 0.5))

## ========== Appendix: Timing comparisons with glmmTMB
## Load the stored benchmark results from data/timings.rda. This file
## presumably provides rt1, rt2, Ns and Ks, which Figure 10 uses and which are
## not defined elsewhere in this script. If the file is missing, the timing
## script is run first; it is looked up relative to the working directory and
## presumably writes timings.rda.
if (!file.exists(file.path(datapath, "timings.rda"))) {
  source("v114i11-timings.R")
}
load(file.path(datapath, "timings.rda"))

## Illustration of the benchmarked setting (not run; presumably mirrors
## v114i11-timings.R). There are N observations in each of K groups, and y
## depends linearly on x. Each group has an independent random intercept
## (sd[2]) and random slope (sd[3]); sd[1] is the residual SD.
# sim_lmer <- function(N = 10, K = 20, b = c(251, 10), sd = c(26, 25, 6)) {
#   x <- runif(N * K, 0, 9)
#   y <- drop(cbind(1, x) %*% b) + rep(rnorm(K, sd = sd[2]), each = N) +
#     rep(rnorm(K, sd = sd[3]), each = N) * x + rnorm(N * K, sd = sd[1])
#   data.frame(x = x, y = y, g = factor(rep(seq(K), each = N)))
# }

## The three fits being compared, all on the same simulated data set.
# df <- sim_lmer()
# library("glmmTMB")
# m1 <- glmmTMB(y ~ x + (x | g), data = df)

# library("tramME")
# m2 <- LmME(y ~ x + (x | g), data = df)
# m3 <- BoxCoxME(y ~ x + (x | g), data = df)

## ---- Figure 10
## Timings of glmmTMB, LmME and BoxCoxME on log scale. The top panel varies
## the group size (rt1 / Ns); the bottom panel varies the number of groups
## (rt2 / Ks).

## Stack a list of microbenchmark results (one element per setting) into a
## plain data.frame. Times are converted to seconds, and a `size` column
## (mult * sizes[i] for every row of element i) identifies the setting. Rows
## are counted per element instead of assuming a fixed number of repetitions.
timing_frame <- function(rt, sizes, mult) {
  out <- do.call("rbind", rt)
  out$size <- mult * rep(sizes, times = vapply(rt, nrow, integer(1)))
  class(out) <- class(out)[-1]  ## drop "microbenchmark" -> plain data.frame
  out$time <- out$time / 1e+09   ## nanoseconds -> seconds
  out
}

## Draw one panel: boxplots of time by expression within each setting on a
## log axis, with a horizontal grid behind the boxes, dotted separators
## between settings and one axis label per setting.
timing_panel <- function(rr, labels, xlab, col) {
  nexpr <- nlevels(rr$expr)
  nset <- length(labels)
  boxplot(time ~ expr + size, data = rr, log = "y", col = col,
          xaxt = "n", xlab = "", ylab = "")
  grid(nx = NA, ny = NULL)
  par(new = TRUE)  ## redraw the boxes on top of the grid lines
  boxplot(time ~ expr + size, data = rr, log = "y", col = col,
          xaxt = "n", ylab = "Time (seconds, log-scale)", xlab = xlab)
  ## No separators when there is only one setting.
  if (nset > 1) {
    abline(v = nexpr + 0.5 + (seq_len(nset - 1) - 1) * nexpr,
           col = "grey90", lty = "dotted")
  }
  axis(1, at = (nexpr + 1) / 2 + (seq_len(nset) - 1) * nexpr,
       labels = labels, tick = FALSE)
}

rr <- timing_frame(rt1, Ns, 20)
## One colour per expression, matched by name; anything other than glmmTMB or
## LmME gets the third colour.
cp <- colorspace::qualitative_hcl(3, "Pastel 1")
cols <- ifelse(levels(rr$expr) == "glmmTMB", cp[1],
        ifelse(levels(rr$expr) == "LmME", cp[2], cp[3]))

par(mfrow = c(2, 1), las = 1, mar = c(4, 4, 1, 1), cex = 1)
## Pl1
timing_panel(rr, Ns, "Group sizes", cols)
legend("topleft", c("glmmTMB", "LmME", "BoxCoxME"),
       pch = 15, col = cp, pt.cex = 2, bty = "n")
## Pl2
rr <- timing_frame(rt2, Ks, 10)
timing_panel(rr, Ks, "Number of random effects (groups)", cols)
