library("watson")
library("rgl")
library("grid")
library("viridis")
library("parallel")
library("ggplot2")
library("mistr")
options(rgl.printRglwidget = TRUE)


# Figure 1

# Replication code for generating the measured times stored in "timesCAPIC10.Rda" and 
# "timesCAPIC10000.Rda" is available as supplementary material at:
# https://www.tandfonline.com/doi/suppl/10.1080/10618600.2024.2416521
# Please note that the measured times are influenced by hardware specifications and the current 
# load on the machine, so exact replication of the timing results may not be possible.
# `results` is expected to be provided by the loaded .Rda file.
load("timesCAPIC10.Rda")
kappas <- c(-100, -50, -10, -1, 1, 10, 50, 100)
ds <- c(3, 5, 10, 20, 50, 100, 200, 1000)
# Row labels are the dimensions `ds`, column labels the concentrations `kappas`.
grid_labels <- list(ds, kappas)

# n = 10: reshape columns 3 and 4 of `results` into ds-by-kappas tables.
A <- matrix(results[, 3], nrow = length(ds), ncol = length(kappas),
            byrow = TRUE, dimnames = grid_labels)
B <- matrix(results[, 4], nrow = length(ds), ncol = length(kappas),
            byrow = TRUE, dimnames = grid_labels)
# Difference of the two tables relative to the element-wise smaller value.
# The row/column labels of A and B carry over to the result.
G <- (B - A) / pmin(B, A)

# n = 10000: same construction on the second set of timings.
load("timesCAPIC10000.Rda")
A <- matrix(results[, 3], nrow = length(ds), ncol = length(kappas),
            byrow = TRUE, dimnames = grid_labels)
B <- matrix(results[, 4], nrow = length(ds), ncol = length(kappas),
            byrow = TRUE, dimnames = grid_labels)
G2 <- (B - A) / pmin(B, A)

# The plot function used in this code is a modified version of the existing `corrplot` function 
# from the `corrplot` package. It has been adjusted to allow plotting on scales other than [-1, 1] 
# and to include annotation rectangles used in the visualization.
# Note: This customization is specific to this visualization and is not of general interest,
# so the function is kept in an external R file for this purpose only.
source(file = "modded_corrplot.R")

# Plotmath captions shared by both panels ("d =" and "kappa =").
d_caption <- bquote(d == .(""))
kappa_caption <- bquote(kappa == .(""))

# Two panels side by side: n = 10 on the left, n = 10000 on the right.
par(mfrow = c(1, 2))

corrplot2(
  t(G),
  method = "color",
  cl.align.text = "l",
  cl.offset = 0.2,
  title = "             n = 10",
  addCoef.col = "black"
)
text(4.5, 9.15, d_caption, col = 2)
text(-0.3, 4.5, kappa_caption, col = 2)

corrplot2(
  t(G2),
  method = "color",
  cl.align.text = "l",
  cl.offset = 0.2,
  title = "               n = 10000",
  addCoef.col = "black"
)
text(4.5, 9.15, d_caption, col = 2)
text(-0.3, 4.5, kappa_caption, col = 2)

# Back to a single-panel layout.
par(mfrow = c(1, 1))


# Sampling examples
# Draw two samples of 2000 points each from equal-weight, two-component
# Watson mixtures. Both use the same direction columns, (1, 1, 1) and
# (-1, 1, 1); they differ only in the concentration: kappa = 20 for sample1
# and kappa = -200 for sample2.
set.seed(1)
sample1 <- rmwat(n = 2000, weights = c(0.5, 0.5), kappa = c(20, 20),
                 mu = matrix(c(1, 1, 1, -1, 1, 1), nrow = 3))
sample2 <- rmwat(n = 2000, weights = c(0.5, 0.5), kappa = c(-200, -200), 
                 mu = matrix(c(1, 1, 1, -1, 1, 1), nrow = 3))

# Figure 2. Code used for plotting the samples

# Panel for sample1: new rgl device, fixed viewpoint, white background.
open3d()
view3d(theta = 195, phi = 10, zoom = 0.5)
bg3d(color = "white")
# Points are coloured by id(sample1) -- presumably the generating component
# of each draw; TODO confirm against the watson documentation.
# NOTE(review): xlim/ylim/zlim and xlab/ylab/zlab are passed through `...`;
# confirm that points3d() actually uses them.
points3d(sample1[, 1], sample1[, 2], sample1[, 3], ylim = c(-1, 1),
         col = id(sample1), xlim = c(-1, 1), zlim = c(-1, 1),
         xlab = "x", ylab = "y", zlab = "z", alpha = 0.8)
# Translucent reference sphere with radius 0.98, i.e. slightly below 1.
spheres3d(x = 0, y = 0, z = 0, radius = 0.98, col = "green", alpha = 0.6, back = "lines")
# Coordinate axes from -1.5 to 1.5: x in black, y in blue, z in red.
segments3d(c(-1.5, 1.5), c(0, 0), c(0, 0), color = "black")
segments3d(c(0, 0), c(-1.5, 1.5), c(0, 0), color = "blue")
segments3d(c(0, 0), c(0, 0), c(-1.5, 1.5), color = "red")

# Panel for sample2: identical scene set-up, drawn on a second device.
open3d()
view3d(theta = 195, phi = 10, zoom = 0.5)
bg3d(color = "white")
points3d(sample2[, 1], sample2[, 2], sample2[, 3], ylim = c(-1, 1),
         col = id(sample2), xlim = c(-1, 1), zlim = c(-1, 1),
         xlab = "x", ylab = "y", zlab = "z", alpha = 0.8)
spheres3d(x = 0, y = 0, z = 0, radius = 0.98, col = "green", alpha = 0.6, back = "lines")
segments3d(c(-1.5, 1.5), c(0, 0), c(0, 0), color = "black")
segments3d(c(0, 0), c(-1.5, 1.5), c(0, 0), color = "blue")
segments3d(c(0, 0), c(0, 0), c(-1.5, 1.5), color = "red")

# Household example
# Household expenditure data: keep the three expenditure columns used for
# clustering and the known gender label for later comparison.
data("household", package = "HSAUR3")
expenditure_cols <- c("housing", "food", "service")
x <- household[, expenditure_cols]
gender <- household$gender

# Fit mixtures with 1 to 4 components and print their BIC values.
set.seed(1)
wat <- lapply(seq_len(4), function(n_comp) watson(x, k = n_comp))
vapply(wat, BIC, numeric(1))

# Start from 6 components with a minimum-weight threshold, print the fit,
# and cross-tabulate the predicted labels against gender.
set.seed(1)
watt <- watson(x, k = 6, minweight = 0.15, nruns = 100)
watt
table(predict(watt), gender)

# Figure 3
# Scale every row of x to unit Euclidean length so the observations can be
# drawn on the sphere.
x <- x/sqrt(rowSums(x^2))
gender <- household$gender
# One single-component fit per gender; the two fitted directions are stacked
# as rows of `mu`.
# NOTE(review): `$mu` presumably resolves to the `mu_matrix` element used
# elsewhere in this file via partial matching of `$` -- confirm.
mu <- rbind(t(watson(x[gender == "female", ], k = 1)$mu), t(watson(x[gender == "male", ], k = 1)$mu))
rownames(mu) <- c("female", "male")

# Panel "Known group membership": new rgl device with fixed viewpoint.
open3d()
view3d(theta = 62, phi = 25, zoom = 0.6)
bg3d(color = "white")
# Observations coloured by the integer code of the gender factor.
points3d(x[, 1], x[, 2], x[, 3], ylim = c(-1, 1), col = as.integer(c(gender)), xlim = c(-1, 1), 
         zlim = c(-1, 1), alpha = 0.8)
# Fitted directions drawn as larger points (size = 10) in colours 1 and 2.
points3d(mu[, 1], mu[, 2], mu[, 3], ylim = c(-1, 1), col = c(1, 2), xlim = c(-1, 1), zlim = c(-1, 1), 
        alpha = 0.7, size = 10)
# Translucent green sphere plus a grey wireframe sphere as reference surface.
spheres3d(x = 0, y = 0, z = 0, radius = 0.98, col = "green", alpha = 0.6, polygon_offset = 1)
spheres3d(x = 0, y = 0, z = 0, radius = 0.98, col = "grey", front = "lines", back = "lines")
# Coordinate axes inside the sphere: x in black, y in blue, z in red.
segments3d(c(-0.98, 0.98), c(0, 0), c(0, 0), color = "black")
segments3d(c(0, 0), c(-0.98, 0.98), c(0, 0), color = "blue")
segments3d(c(0, 0), c(0, 0), c(-0.98, 0.98), color = "red")
title3d(main = substitute(paste(bold("Known group membership"))), level = -7.5, line = 0.7)

# Panel "K = 2": second element of `wat`, i.e. the fit requested with k = 2.
w2 <- wat[[2]]
# Transpose so that each row of mu2 is one fitted direction.
mu2 <- t(w2$mu_matrix)
open3d()
view3d(theta = 62, phi = 25, zoom = 0.6)
bg3d(color = "white")
# -predict(w2) + 3 swaps the labels (1 -> 2, 2 -> 1) before they are used as
# colours; the direction colours c(2, 1) follow the same swap.
points3d(x[, 1], x[, 2], x[, 3], ylim = c(-1, 1), col = -predict(w2) + 3, xlim = c(-1, 1),
         zlim = c(-1, 1), alpha = 0.8)
points3d(mu2[, 1], mu2[, 2], mu2[, 3], ylim = c(-1, 1), col = c(2, 1), xlim = c(-1, 1), zlim = c(-1, 1),
         alpha = 0.7, size = 10)
# Translucent green sphere with a slightly smaller grey wireframe inside.
spheres3d(x = 0, y = 0, z = 0, radius = 0.98, col = "green", alpha = 0.6)
spheres3d(x = 0, y = 0, z = 0, radius = 0.975, col = "grey", front = "lines", back = "lines")
# Coordinate axes: x in black, y in blue, z in red.
segments3d(c(-0.98, 0.98), c(0, 0), c(0, 0), color = "black")
segments3d(c(0, 0), c(-0.98, 0.98), c(0, 0), color = "blue")
segments3d(c(0, 0), c(0, 0), c(-0.98, 0.98), color = "red")
title3d(main = substitute(paste(bold("Mixture of Watsons with K = 2"))), level = -7.5, line = 0.7)

# Panel "K = 3": third element of `wat`, i.e. the fit requested with k = 3.
w3 <- wat[[3]]
# Transpose so that each row of mu3 is one fitted direction.
mu3 <- t(w3$mu_matrix)
open3d()
view3d(theta = 62, phi = 25, zoom = 0.6)
bg3d(color = "white")
# Colour mapping of the predicted labels: 1 -> 4, 2 -> 2, 3 -> 1
# (-label + 4 reverses the labels, then colour 3 is replaced by 4). The
# direction colours c(4, 2, 1) use the same mapping.
points3d(x[, 1], x[, 2], x[, 3], ylim = c(-1, 1), col = ifelse(-predict(w3) + 4 == 3, 4, -predict(w3) + 4), 
         xlim = c(-1, 1), zlim = c(-1, 1), alpha = 0.8)
points3d(mu3[, 1], mu3[, 2], mu3[, 3], ylim = c(-1, 1), col = c(4, 2, 1), xlim = c(-1, 1),
         zlim = c(-1, 1), alpha = 0.7, size = 10)
# Translucent green sphere with a slightly smaller grey wireframe inside.
spheres3d(x = 0, y = 0, z = 0, radius = 0.98, col = "green", alpha = 0.6)
spheres3d(x = 0, y = 0, z = 0, radius = 0.975, col = "grey", front = "lines", back = "lines")
# Coordinate axes: x in black, y in blue, z in red.
segments3d(c(-0.98, 0.98), c(0, 0), c(0, 0), color = "black")
segments3d(c(0, 0), c(-0.98, 0.98), c(0, 0), color = "blue")
segments3d(c(0, 0), c(0, 0), c(-0.98, 0.98), color = "red")
title3d(main = substitute(paste(bold("Mixture of Watsons with K = 3"))), level = -7.5, line = 0.7)


# Panel "minweight = 0.15": the fit started with k = 6 and minweight = 0.15.
w4 <- watt
# Transpose so that each row of mu4 is one fitted direction.
mu4 <- t(w4$mu_matrix)
open3d()
view3d(theta = 62, phi = 25, zoom = 0.6)
bg3d(color = "white")
# Predicted labels are used directly as colours.
points3d(x[, 1], x[, 2], x[, 3], ylim = c(-1, 1), col = predict(w4), xlim = c(-1, 1), zlim = c(-1, 1),
         alpha = 0.8)
# NOTE(review): col = c(1, 2) presumes that two components remain in the
# fit -- confirm against the printed model.
points3d(mu4[, 1], mu4[, 2], mu4[, 3], ylim = c(-1, 1), col = c(1, 2), xlim = c(-1, 1), zlim = c(-1, 1),
         alpha = 0.7, size = 10)
# Translucent green sphere with a slightly smaller grey wireframe inside.
spheres3d(x = 0, y = 0, z = 0, radius = 0.98, col = "green", alpha = 0.6)
spheres3d(x = 0, y = 0, z = 0, radius = 0.975, col = "grey", front = "lines", back = "lines")
# Coordinate axes: x in black, y in blue, z in red.
segments3d(c(-0.98, 0.98), c(0, 0), c(0, 0), color = "black")
segments3d(c(0, 0), c(-0.98, 0.98), c(0, 0), color = "blue")
segments3d(c(0, 0), c(0, 0), c(-0.98, 0.98), color = "red")
# Two-line title: each title3d() call places one line.
title3d(main = substitute(paste(bold("Mixture of Watsons with"))), level = -7.5, line = 1.8)
title3d(main = substitute(paste(bold("minweight = 0.15"))), level = -7.5, line = 0.7)


# Simulation Study
# Simulate 2000 points from a five-component Watson mixture: two components
# with kappa = -200 and three with positive kappa (30, 50, 100). The
# directions are the columns of the 3 x 5 `mu` matrix.
set.seed(1)
d <- rmwat(n = 2000, weights = c(0.1, 0.3, 0.2, 0.2, 0.2),
           kappa = c(-200, -200, 30, 50, 100), 
           mu = matrix(c(1, 1, 1, -1, 1, 1, -1, -1, -1, 0, 1, -1, 1, 0, 0), nrow = 3))
# Fit with 7 initial components, a minimum-weight threshold of 0.02 and
# 20 runs, then print the fitted model.
set.seed(1)
model <- watson(d, 7, minweight = 0.02, nruns = 20)
model

# Figure 5 Plot

# First panel: simulated data coloured by id(d) -- presumably the generating
# component of each draw; TODO confirm against the watson documentation.
open3d()
view3d(theta = 195, phi = 10, zoom = 0.5)
bg3d(color = "white")
points3d(d[, 1], d[, 2], d[, 3], ylim = c(-1, 1), col = id(d), xlim = c(-1, 1),
         zlim = c(-1, 1), xlab = "x", ylab = "y", zlab = "z", alpha = 0.8)
# Translucent yellow reference sphere with radius 0.98.
spheres3d(x = 0, y = 0, z = 0, radius = 0.98, col = "yellow", alpha = 0.6, back = "lines")
# Coordinate axes from -1.2 to 1.2: x in black, y in blue, z in red.
segments3d(c(-1.2, 1.2), c(0, 0), c(0, 0), color = "black")
segments3d(c(0, 0), c(-1.2, 1.2), c(0, 0), color = "blue")
segments3d(c(0, 0), c(0, 0), c(-1.2, 1.2), color = "red")

# Recode the predicted labels so that the colours match the first panel:
# first reverse them (label -> 6 - label), then permute 1 -> 2, 2 -> 3,
# 3 -> 1; every other value is kept as it is.
cluster_code <- predict(model) * (-1) + 6
recode <- c(2, 3, 1)
colo <- vapply(cluster_code, function(code) {
  if (code %in% 1:3) recode[code] else code
}, numeric(1))
# Second panel of Figure 5: same scene as the first panel, but the points
# are coloured by the recoded predicted labels in `colo`.
open3d()
view3d(theta = 195, phi = 10, zoom = 0.5)
bg3d(color = "white")
points3d(d[, 1], d[, 2], d[, 3], ylim = c(-1, 1), col = colo,
         xlim = c(-1, 1), zlim = c(-1, 1), xlab = "x", ylab = "y",
         zlab = "z", alpha = 0.8)
# Translucent yellow reference sphere with radius 0.98.
spheres3d(x = 0, y = 0, z = 0, radius = 0.98, col = "yellow",
          alpha = 0.6, back = "lines")
# Coordinate axes from -1.2 to 1.2: x in black, y in blue, z in red.
segments3d(c(-1.2, 1.2), c(0, 0), c(0, 0), color = "black")
segments3d(c(0, 0), c(-1.2, 1.2), c(0, 0), color = "blue")
segments3d(c(0, 0), c(0, 0), c(-1.2, 1.2), color = "red")

# New Zealand earthquake data example: preprocessing (see appendix)
# Read the three data sets and tag each row with its cluster label.
canterbury_early <- cbind(read.csv("cantyearlyPTaxes.csv", header = TRUE),
                          type = "CE")
canterbury_late <- cbind(read.csv("cantylatePTaxes.csv", header = TRUE),
                         type = "CL")
lewis <- cbind(read.csv("lewisPTaxes.csv", header = TRUE), type = "L")
dataa <- rbind(canterbury_early, canterbury_late, lewis)

# Keep the labels separately and drop the fourth column from the angle table.
classif <- as.factor(dataa$type)
dataa <- dataa[, -4]

# Rows with theta < pi/2 are remapped: column 1 + pi, pi - column 2,
# pi - column 3.
# NOTE(review): assumes the first three columns are phi, theta, psi in that
# order -- confirm against the CSV headers.
ind <- dataa$theta >= pi / 2
needs_flip <- !ind
dataa[needs_flip, ] <- cbind(dataa[needs_flip, 1] + pi,
                             pi - dataa[needs_flip, 2],
                             pi - dataa[needs_flip, 3])

# Convert to strike / dip / rake angles.
new_angles <- data.frame(strike = dataa$phi - pi / 2,
                         dip = pi - dataa$theta,
                         rake = dataa$psi + pi / 2)

# Three-column matrix of Cartesian coordinates derived from the angles.
b <- with(new_angles, cbind(
  -sin(rake) * cos(strike) + cos(rake) * cos(dip) * sin(strike),
  sin(rake) * sin(strike) + cos(rake) * cos(dip) * cos(strike),
  cos(rake) * sin(dip)
))
save(b, classif, file = "null.RData")

## New Zealand example
load("null.RData")

# Model with one Watson component per known cluster label.
gg <- watson(b, ids = classif)
gg

# Model in which every observation shares a single label.
one <- watson(b, ids = rep("one", length(classif)))
one

# Parametric bootstrap: simulate from the single-component fit, refit both
# models on every simulated sample and record the log-likelihood difference.
set.seed(1)
B <- 10000
# The simulated sample has to be as long as the label vector it is fitted
# with, so the size is derived from `classif` rather than hard-coded.
n_obs <- length(classif)
samples <- vapply(seq_len(B), function(x) {
  sample1 <- rmwat(n_obs, 1, one$kappa_vector, one$mu_matrix)
  model3 <- watson(sample1, ids = classif)
  model1 <- watson(sample1, ids = rep("a", n_obs))
  as.numeric(logLik(model3) - logLik(model1))
}, numeric(1))

# Share of simulated differences that exceed the observed difference.
observed_diff <- as.numeric(logLik(gg) - logLik(one))
mean(samples > observed_diff)


## Comparison of two Christchurch clusters in appendix
set.seed(1)
# Restrict to the two clusters labelled "CE" and "CL".
# Note: the variable `c` shadows base::c for value lookups only; calls such
# as c(...) still resolve to the function.
is_christchurch <- classif == "CE" | classif == "CL"
c <- b[is_christchurch, ]
classifi <- classif[is_christchurch]

# One component per label versus a single shared label.
gg2 <- watson(c, ids = classifi)
one2 <- watson(c, ids = rep("one", length(classifi)))

# Parametric bootstrap as above. The simulated sample size is derived from
# the label vector it is fitted with rather than hard-coded.
n_obs2 <- length(classifi)
samples2 <- vapply(seq_len(B), function(x) {
  sample1 <- rmwat(n_obs2, 1, one2$kappa_vector, one2$mu_matrix)
  model3 <- watson(sample1, ids = classifi)
  model1 <- watson(sample1, ids = rep("one", n_obs2))
  as.numeric(logLik(model3) - logLik(model1))
}, numeric(1))

# Share of simulated differences that exceed the observed difference.
observed_diff2 <- as.numeric(logLik(gg2) - logLik(one2))
mean(samples2 > observed_diff2)

## Depth image clustering example

# Image number; it is zero-padded to six digits for the file names.
num <- 451
pasnum <- sprintf("%06d", num)

depth_file <- paste0("depthcsv/depth_", pasnum, ".csv")
normals_file <- paste0("surface_normals_csv/surface_normals_", pasnum, ".csv")
rgb_file <- paste0("rgbcsv/rgb_", pasnum, ".csv")
depth <- read.csv(depth_file, header = FALSE)
norm <- read.csv(normals_file, header = FALSE)
rgbpic <- read.csv(rgb_file, header = FALSE)

# All pages drawn from here on go to this PDF until the device is closed.
pdf(paste0("pic_", pasnum, ".pdf"), width = 6, height = 4)

# Page 1: the RGB picture, reshaped to 427 x 561 colour strings.
col <- rgb(rgbpic$V1, rgbpic$V2, rgbpic$V3, maxColorValue = 255)
dim(col) <- c(427, 561)
grid.newpage()
grid.raster(col, interpolate = FALSE)

# Page 2: the depth map, rescaled to [0, 255] and shown in the red channel
# only (green and blue are zero).
depth_shifted <- depth - min(depth)
r <- as.matrix((255 / max(depth_shifted)) * depth_shifted)
g <- matrix(0, nrow = 427, ncol = 561)
b <- matrix(0, nrow = 427, ncol = 561)
col <- rgb(r, g, b, maxColorValue = 255)
dim(col) <- dim(r)
grid.newpage()
grid.raster(col, interpolate = FALSE)

# Surface normals as a plain matrix; this is the input to the mixture fits.
a <- as.matrix(norm)
# Fit a Watson mixture to `a` starting from `i` components, using the default
# assignment step, minweight = 0.05 and 100 runs.
#
# i: number of initial components (passed as `k`).
# a: data matrix handed to watson().
# Returns the fitted object with its `data` element removed (presumably to
# keep the object returned from the workers small).
watrun <- function(i, a) {
  # TRUE rather than T: `T` is an ordinary variable that can be reassigned.
  fit <- watson(a, k = i, minweight = 0.05, nruns = 100, verbose = TRUE)
  fit$data <- NULL
  fit
}
# Same as the soft-assignment fit, but with E = "hardmax".
#
# i: number of initial components (passed as `k`).
# a: data matrix handed to watson().
# Returns the fitted object with its `data` element removed (presumably to
# keep the object returned from the workers small).
watrunhard <- function(i, a) {
  # TRUE rather than T: `T` is an ordinary variable that can be reassigned.
  fit <- watson(a, k = i, E = "hardmax", minweight = 0.05, nruns = 100,
                verbose = TRUE)
  fit$data <- NULL
  fit
}

# Soft-assignment fits for k = 2, ..., 7 on three worker processes; worker
# output is written to "progress.txt".
cl <- makeCluster(3, outfile = "progress.txt")
# `finally` shuts the workers down even if the export or one of the fits
# fails, so no worker processes are left behind.
B <- tryCatch({
  clusterExport(cl = cl, varlist = "watson")
  parallel::clusterSetRNGStream(cl = cl, iseed = 1)
  parLapply(cl, 2:7, watrun, a)
}, finally = stopCluster(cl))

# For every soft-assignment fit: draw the 427 x 561 image of predicted
# labels on a new page and collect the BIC of the fit.
BIC_S <- unlist(lapply(B, function(w) {
  n_comp <- length(w$kappa_vector)
  # Component with the smallest |kappa|. which.min() always returns exactly
  # one index, so the swap below is well defined even with ties.
  least_conc <- which.min(abs(w$kappa_vector))
  cols <- magma(n_comp)
  # Give that component the first magma colour by swapping the two entries.
  cols[c(1, least_conc)] <- cols[c(least_conc, 1)]
  label_img <- matrix(cols[predict(w)], 427, 561)
  grid.newpage()
  grid.raster(label_img, interpolate = FALSE)
  BIC(w)
}))

# Hard-assignment fits for k = 2, ..., 7 on three worker processes; worker
# output is written to "progress.txt".
cl <- makeCluster(3, outfile = "progress.txt")
# `finally` shuts the workers down even if the export or one of the fits
# fails, so no worker processes are left behind.
B <- tryCatch({
  clusterExport(cl = cl, varlist = "watson")
  parallel::clusterSetRNGStream(cl = cl, iseed = 1)
  parLapply(cl, 2:7, watrunhard, a)
}, finally = stopCluster(cl))

# For every hard-assignment fit: draw the 427 x 561 image of predicted
# labels on a new page and collect the BIC of the fit.
BIC_H <- unlist(lapply(B, function(w) {
  n_comp <- length(w$kappa_vector)
  # Component with the smallest |kappa|. which.min() always returns exactly
  # one index, so the swap below is well defined even with ties.
  least_conc <- which.min(abs(w$kappa_vector))
  cols <- magma(n_comp)
  # Give that component the first magma colour by swapping the two entries.
  cols[c(1, least_conc)] <- cols[c(least_conc, 1)]
  label_img <- matrix(cols[predict(w)], 427, 561)
  grid.newpage()
  grid.raster(label_img, interpolate = FALSE)
  BIC(w)
}))

# Last page: BIC against the number of initial clusters for both assignment
# schemes, on a common y-range.
k_values <- 2:7
bic_range <- range(c(BIC_S, BIC_H))
plot(
  k_values, BIC_S,
  ylim = bic_range,
  ylab = "BIC Scores",
  xlab = "number of initial clusters k"
)
points(k_values, BIC_H, pch = 4, col = 2)
legend(
  "topright",
  pch = c(1, 4),
  col = c(1, 2),
  legend = c("Soft Assignment", "Hard Assignment")
)
# Close the PDF device.
dev.off()

# Note for Figure 8:
# The clustered and generated images can be obtained by running the depth image clustering section 
# with a different image number set by the "num" variable (e.g., "num <- 450").
# This allows the generation of clustered surface normals for different depth images as shown in Figure 8.



# Code to reproduce Figure 4.  This figure was generated originally
# for the Mathematics of Computation paper (Sablica and Hornik
# (2022)) before the package was created.  The individual bounds are
# not exported from the C++ routine and thus this cannot be
# reproduced only using the watson package.  The files 'a05b50' and
# 'a995b100' were generated in Mathematica, see Generate_g.nb

# Evaluate two sequences of values for the arguments (a, b, z), each refined
# by N steps of a backward recurrence.
#
# a, b:   numeric parameters.
# z:      numeric vector of arguments.
# N:      number of recurrence steps; for N = 0 the closed-form starting
#         values are returned unchanged.
# change: if TRUE, return 1 - value instead of value.
#
# Returns a data.frame with 2 * length(z) rows: column `y` holds the values
# from the first starting expression followed by those from the second, and
# the factor `type` labels them "l" and "u" respectively.
iratio_gneg <- function(a, b, z, N, change = FALSE) {
  first_start <- (2 * (a + N)) /
    (b + N - z + sqrt((z - b - N)^2 + 4 * (a + N) * z))
  second_start <- 1 - (2 * (b - a)) /
    ((b + N) - 1 + z + sqrt(-4 * ((b - a) + 1) * z + (z + (b + N + 1))^2))
  ratio <- c(first_start, second_start)

  # Backward recurrence from index N - 1 down to 0. `z` is recycled across
  # both halves of `ratio`.
  if (N > 0) {
    ratio <- Reduce(
      function(acc, i) (a + i) / ((b + i) - z + z * acc),
      (N - 1):0,
      ratio
    )
  }

  if (change) {
    ratio <- 1 - ratio
  }

  labels <- rep(c("l", "u"), each = length(z))
  data.frame(y = ratio, type = as.factor(labels))
}
# Evaluate both sequences from iratio_gneg() on a grid `z` that may contain
# negative, zero and positive values.
#
# a, b: numeric parameters.
# z:    numeric vector of arguments.
# N:    number of recurrence steps forwarded to iratio_gneg().
#
# Returns a data.frame with 2 * length(z) rows: the first length(z) rows
# carry the label "l", the remaining rows the label "u", both in the order
# of `z`.
#   z == 0: the value a / b for both labels.
#   z <  0: iratio_gneg(a, b, z, N).
#   z >  0: iratio_gneg(b - a, b, -z, N, change = TRUE).
# NOTE(review): for z > 0 the values are 1 - ratio, which reverses the
# ordering of the two sequences while the "l"/"u" labels stay on the same
# rows -- confirm that this is the intended labelling.
g <- function(a, b, z, N) {
  # The result stacks two copies of the grid, so every mask is built on the
  # doubled vector.
  z_stacked <- rep(z, 2)
  at_zero <- z_stacked == 0
  negative <- z_stacked < 0
  positive <- z_stacked > 0

  out <- data.frame(y = numeric(length(z) * 2), type = factor(c("l", "u")))

  # The three masks are disjoint, so the fill order does not matter.
  out[negative, ] <- iratio_gneg(a, b, z[z < 0], N)
  out[positive, ] <- iratio_gneg(b - a, b, -z[z > 0], N, change = TRUE)

  n_zero <- sum(z == 0)
  out[at_zero, ] <- data.frame(
    y = rep(a / b, 2 * n_zero),
    type = factor(rep(c("l", "u"), each = n_zero))
  )
  out
}

# Panel p1: both sequences of g(a = 0.5, b = 50) for several numbers of
# recurrence steps, together with the reference curve read from "a05b50.csv".
x <- seq(-100, 200, 0.1)
a <- 0.5
b <- 50
# Numbers of recurrence steps shown. The same vector drives the computation
# and the legend labels so that the two cannot get out of sync. (The former
# stand-alone `N <- 3` was never used and has been removed.)
n_steps <- c(0, 1, 5)
d <- do.call("rbind", lapply(n_steps, function(i) data.frame(x = x, g(a, b, x, i))))
# g() returns 2 * length(x) rows per setting, hence each = 2 * length(x).
d <- cbind(d, N = as.factor(rep(n_steps, each = 2 * length(x))))

# Reference values, one per grid point, shown as an extra curve labelled "true".
a05b50 <- read.csv(file = "a05b50.csv", header = FALSE)
colnames(a05b50) <- "y"
d2 <- cbind(x, a05b50, type = as.factor("l"), N = as.factor("true"))
d3 <- rbind(d, d2)

p1 <- ggplot(d3) + geom_line(aes(x, y, col = N, lty = type), linewidth = 0.8) +
                   theme(legend.position = "top") + 
                   ggplot2::labs(x = NULL, y = NULL, title = expression("g(a=0.5, b=50)"))

# Panel p2: relative differences of the sequences to the reference values
# for a larger set of recurrence steps.
steps_rel <- c(0, 1, 5, 10, 30)
curves <- lapply(steps_rel, function(n_iter) {
  data.frame(x = x, g(a, b, x, n_iter))
})
d <- do.call("rbind", curves)
d$N <- as.factor(rep(steps_rel, each = 2 * length(x)))
# The reference column is recycled over the stacked curves.
d$y <- (d$y - a05b50$y) / a05b50$y
p2 <- ggplot(d) +
  geom_line(aes(x, y, col = N, lty = type), linewidth = 0.8) +
  theme(legend.position = "top") +
  ggplot2::labs(
    x = NULL, y = NULL,
    title = expression("g(a=0.5, b=50) - relative differences")
  )



# Panel p3: both sequences of g(a = 99.5, b = 100) for several numbers of
# recurrence steps, together with the reference curve read from
# "a995b100.csv".
x <- seq(-500, 100, 0.1)
a <- 99.5
b <- 100
# Numbers of recurrence steps shown. The same vector drives the computation
# and the legend labels so that the two cannot get out of sync. (The former
# stand-alone `N <- 50` was never used and has been removed.)
n_steps <- c(0, 1, 5)
d <- do.call("rbind", lapply(n_steps, function(i) data.frame(x = x, g(a, b, x, i))))
# g() returns 2 * length(x) rows per setting, hence each = 2 * length(x).
d <- cbind(d, N = as.factor(rep(n_steps, each = 2 * length(x))))

# Reference values, one per grid point, shown as an extra curve labelled "true".
a995b100 <- read.csv(file = "a995b100.csv", header = FALSE)
colnames(a995b100) <- "y"
d2 <- cbind(x, a995b100, type = as.factor("l"), N = as.factor("true"))
d3 <- rbind(d, d2)
p3 <- ggplot(d3) + geom_line(aes(x, y, col = N, lty = type), linewidth = 0.8) +
                   theme(legend.position = "none") + 
                   ggplot2::labs(x = NULL, y = NULL, title = expression("g(a=99.5, b=100)"))

# Panel p4: relative differences of the sequences to the reference values
# for a larger set of recurrence steps.
steps_rel <- c(0, 1, 5, 10, 30)
curves <- lapply(steps_rel, function(n_iter) {
  data.frame(x = x, g(a, b, x, n_iter))
})
d <- do.call("rbind", curves)
d$N <- as.factor(rep(steps_rel, each = 2 * length(x)))
# The reference column is recycled over the stacked curves.
d$y <- (d$y - a995b100$y) / a995b100$y
p4 <- ggplot(d) +
  geom_line(aes(x, y, col = N, lty = type), linewidth = 0.8) +
  theme(legend.position = "none") +
  ggplot2::labs(
    x = NULL, y = NULL,
    title = expression("g(a=99.5, b=100) - relative differences")
  )

# Arrange the four panels in two columns.
mistr:::multiplot(p1, p3, p2, p4, cols = 2)


