# install.packages("hmmTMB")
library("hmmTMB")


###########################################################################
## Replicate Section 3  ------
## of "hmmTMB: Hidden Markov Models with Flexible Covariate Effects in R"
###########################################################################

# Simulate a small toy data set for a single time series (ID = 1):
#   Z1    cumulative sum of Gaussian increments (sd = 5)
#   Z2    Poisson counts (rate 4)
#   cov1  cumulative sum of Gaussian increments (sd = 0.1)
# The three sets of random numbers are drawn in this order, so the seed
# determines the data.
set.seed(34634526)
n <- 10
walk_z1 <- cumsum(rnorm(n, mean = 0, sd = 5))
count_z2 <- rpois(n, lambda = 4)
walk_cov1 <- cumsum(rnorm(n, mean = 0, sd = 0.1))
dat <- data.frame(ID = 1, Z1 = walk_z1, Z2 = count_z2, cov1 = walk_cov1)
head(dat)

###########################################################################
## Replicate the petrel movement analysis presented in Section 7  ------
## of "hmmTMB: Hidden Markov Models with Flexible Covariate Effects in R"
###########################################################################

#########################
## 1. Data preparation ##
#########################
# This section is not included in the manuscript
library(moveHMM)
library(sp)

# DATA REFERENCE: Descamps S, Tarroux A, Cherel Y, Delord K, Godø OR, 
# Kato A, Krafft BA, Lorentsen S, Ropert-Coudert Y, Skaret G, Varpe Ø. 2016. 
# Data from: At-sea distribution and prey selection of Antarctic petrels 
# and commercial krill fisheries. Movebank Data Repository. 
# https://doi.org/10.5441/001/1.q4gn4q56
# Read the GPS file and keep the four columns used below (selected by
# position), renamed to ID, time, lon and lat
data_file <- paste0("At-sea distribution Antarctic Petrel, Antarctica 2012 ",
                    "(data from Descamps et al. 2016)-gps.csv")
raw_gps <- read.csv(data_file)
data_all <- setNames(raw_gps[, c(13, 3, 4, 5)],
                     c("ID", "time", "lon", "lat"))
# NOTE(review): timestamps are parsed in the "MST" time zone; confirm that
# this matches the time zone used in the source file
data_all$time <- as.POSIXct(data_all$time, tz = "MST")

# Compute first-order time differences within groups
#
# Arguments:
#   x    vector of date-times, with the observations of each group stored
#        contiguously
#   ID   vector of group identifiers, same length as x
#   ...  further arguments passed to difftime() (e.g., units)
#
# Returns a plain vector of length(ID): entry i is x[i + 1] - x[i], and the
# last entry of every group is NA. The difftime class is dropped, so the
# values are in whatever unit difftime() used (chosen automatically unless
# 'units' is passed through '...').
diff_by_ID <- function(x, ID, ...) {
  n <- length(ID)
  # Positions after which the ID changes, i.e., the end of each group
  # except the last one
  breaks <- which(ID[-n] != ID[-1])
  starts <- c(1, breaks + 1)
  ends <- c(breaks, n)

  # Pair each observation (except group ends) with its successor (except
  # group starts); the two index sets line up one-to-one
  lagged <- difftime(time1 = x[-starts], time2 = x[-ends], ...)

  out <- rep(NA, n)
  out[-ends] <- lagged
  return(out)
}

# Keep only a few tracks for this example: drop every track that contains an
# unusually long interval, then retain the first 10 remaining tracks.
# NOTE(review): diff_by_ID() is called without 'units', so the threshold of 30
# is in the unit that difftime() selects automatically -- confirm this is the
# intended unit for these data.
dtimes <- diff_by_ID(data_all$time, data_all$ID)
all_ids <- unique(data_all$ID)
gappy_ids <- unique(data_all$ID[which(dtimes > 30)])
keep_ids <- setdiff(all_ids, gappy_ids)[1:10]

# Subset to the selected tracks and sort by track, then by time
data <- data_all[data_all$ID %in% keep_ids, ]
data <- data[order(data$ID, data$time), ]

# Define centre for each track as first observation. 'data' is sorted by ID
# (then time), so a track starts wherever the ID differs from the previous row.
i0 <- c(1, which(data$ID[-1] != data$ID[-nrow(data)]) + 1)
centres <- data[i0, c("ID", "lon", "lat")]

# Add distance to centre as covariate (based on sp for great circle distance;
# with longlat = TRUE the distance is in km). spDistsN1() is vectorised over
# the rows of 'pts', so it is called once per track rather than once per row,
# and no temporary centre columns are needed.
track_index <- cumsum(seq_len(nrow(data)) %in% i0)
data$d2c <- NA_real_
for (k in seq_along(i0)) {
  rows <- which(track_index == k)
  data$d2c[rows] <- spDistsN1(
    pts = cbind(as.numeric(data$lon[rows]), as.numeric(data$lat[rows])),
    pt = c(centres$lon[k], centres$lat[k]),
    longlat = TRUE
  )
}

# Derive step lengths and turning angles using moveHMM
movehmm_data <- prepData(trackData = data, 
                         coordNames = c("lon", "lat"), 
                         type = "LL")
data$step <- movehmm_data$step
data$angle <- movehmm_data$angle

# Replace zero step lengths with a very small random number, because it is
# overkill to use a zero-inflated distribution just for one zero observation
set.seed(2855)
wh_zero <- which(data$step == 0)
# Smallest strictly positive step, used as upper bound for the replacement
# values. Positive steps are selected directly rather than with negative
# indexing (data$step[-wh_zero]), which would select nothing if there were
# no zero steps.
min_pos_step <- min(data$step[which(data$step > 0)])
data$step[wh_zero] <- runif(length(wh_zero), min = 0, max = min_pos_step)

# Shorten track names to "PET-A", "PET-B", ... (one letter per track)
data$ID <- factor(data$ID)
levels(data$ID) <- paste0("PET-", LETTERS[seq_len(nlevels(data$ID))])

######################
## 2. Data analysis ##
######################
# Load package
library("hmmTMB")

# Make sure the track ID is a factor (used for random effects)
data$ID <- factor(data$ID)
head(data)

# Starting values for the state-dependent parameters (one value per state)
step_mean0 <- c(1, 6, 20)
step_sd0 <- c(1, 5, 10)
angle_mean0 <- rep(0, 3)
angle_rho0 <- c(0.8, 0.8, 0.9)

# Observation distributions for the two data streams
dists <- list(step = "gamma2", angle = "wrpcauchy")

# Initial parameters, organised as one sub-list per data stream
par0 <- list()
par0$step <- list(mean = step_mean0, sd = step_sd0)
par0$angle <- list(mu = angle_mean0, rho = angle_rho0)

# Observation model with 3 states
obs <- Observation$new(data = data, dists = dists, n_states = 3, par = par0)

# Formulas for the transition probabilities, laid out as a 3 x 3 matrix:
# off-diagonal entries hold the formula of the corresponding transition
# ("~1" is intercept-only), and diagonal entries are "."
f <- "~ s(ID, bs = 're') + s(d2c, k = 10, bs = 'cs')"
tpm_structure <- rbind(c(".",  f,   "~1"),
                       c(f,    ".", f),
                       c("~1", f,   "."))

# Initial transition probability matrix (each row sums to 1)
tpm0 <- rbind(c(0.9, 0.1, 0),
              c(0.1, 0.8, 0.1),
              c(0,   0.1, 0.9))

# Hidden state model
hid <- MarkovChain$new(n_states = 3,
                       formula = tpm_structure,
                       data = data,
                       tpm = tpm0,
                       initial_state = "stationary")

# Parameters to treat as fixed: the angle mean intercepts in all three states
# and the intercepts of the 1 -> 3 and 3 -> 1 transitions. Each is given as a
# name with value NA.
fixed_obs <- paste0("angle.mu.state", 1:3, ".(Intercept)")
fixed_hid <- c("S1>S3.(Intercept)", "S3>S1.(Intercept)")
fixpar <- list(obs = setNames(rep(NA, length(fixed_obs)), fixed_obs),
               hid = setNames(rep(NA, length(fixed_hid)), fixed_hid))

# Combine both model components and fit
hmm <- HMM$new(obs = obs, hid = hid, fixpar = fixpar)
hmm$fit(silent = TRUE)

# Plot tracks coloured by state
hmm$plot_ts("lon", "lat") +
  coord_map("mercator") +
  geom_point(size = 0.3) +
  labs(x = "longitude", y = "latitude")

# State-dependent step length distributions. The legend is placed inside the
# panel with the ggplot2 (>= 3.5.0) interface: legend.position = "inside"
# plus legend.position.inside for the coordinates (a numeric legend.position
# is deprecated).
hmm$plot_dist("step") + 
  coord_cartesian(ylim = c(0, 0.45)) +
  theme(legend.position = "inside",
        legend.position.inside = c(0.8, 0.8)) +
  labs(x = "step (km)")

# State-dependent turning angle distributions
hmm$plot_dist("angle") +
  coord_cartesian(ylim = c(0, 2.2)) +
  theme(legend.position = "none") +
  scale_x_continuous(breaks = seq(-pi, pi, by = pi/2), 
                     labels = expression(-pi, -pi/2, 0, pi/2, pi))

# Transition prob Pr(3 -> 2)
hmm$plot(what = "tpm", var = "d2c", i = 3, j = 2) +
  labs(x = "distance to centre (km)")

# Stationary state probabilities. margin() takes one scalar per side
# (top, right, bottom, left), so the negative bottom margin is passed as 'b'
# rather than inside a single vector.
hmm$plot(what = "delta", var = "d2c") +
  theme(legend.position = "top",
        legend.margin = margin(t = 0, r = 0, b = -10, l = 0)) +
  labs(title = NULL, x = "distance to centre (km)")

# Stationary state probabilities by ID (with d2c fixed to 1500)
hmm$plot(what = "delta", var = "ID", covs = list(d2c = 1500))

# "hid" gives the component for the hidden state model
hmm$sd_re()$hid

###########################################################################
## Replicate the energy price analysis presented in Appendix B  ------
## of "hmmTMB: Hidden Markov Models with Flexible Covariate Effects in R"
###########################################################################
library("hmmTMB")
data("energy", package = "MSwM")

###########################################
## Model specification and model fitting ##
###########################################
# Observation model: normally distributed price, with the mean modelled as a
# smooth function of EurDol and the standard deviation as a cubic polynomial
dists <- list(Price = "norm")
par0 <- list(Price = list(mean = c(3, 6), sd = c(1, 1)))
mean_formula <- ~ s(EurDol, k = 10, bs = "cs")
sd_formula <- ~ poly(EurDol, 3)
f <- list(Price = list(mean = mean_formula, sd = sd_formula))
obs <- Observation$new(data = energy,
                       n_states = 2,
                       dists = dists,
                       par = par0,
                       formulas = f)

# Hidden state model: 2 states, no formula supplied
hid <- MarkovChain$new(data = energy, n_states = 2)

# Combine the two components and fit
hmm <- HMM$new(hid = hid, obs = obs)
hmm$fit(silent = TRUE)

#############
## Results ##
#############
# Get most likely state sequence for plotting
energy$viterbi <- factor(paste0("State ", hmm$viterbi()))

# Plot mean price in each state, with data points
hmm$plot(what = "obspar", var = "EurDol", i = "Price.mean") + 
  geom_point(aes(x = EurDol, y = Price, fill = viterbi, col = viterbi), 
             data = energy, alpha = 0.3) +
  theme(legend.position = "none")

# Plot price standard deviation in each state. legend.position.inside only
# takes effect when legend.position is set to "inside", so both are given.
hmm$plot(what = "obspar", var = "EurDol", i = "Price.sd") +
  theme(legend.position = "inside",
        legend.position.inside = c(0.3, 0.7))

###################################
## State-dependent distributions ##
###################################
# Get state-dependent parameters for a few values of EurDol
EurDol <- seq(0.65, 1.15, by = 0.1)
newdata <- data.frame(EurDol = EurDol)
par <- hmm$predict(what = "obspar", newdata = newdata)

# Weights for state-dependent distributions: proportion of time points
# assigned to each state by the Viterbi algorithm. tabulate() with nbins = 2
# keeps one entry per state even if a state never appears in the sequence.
w <- tabulate(hmm$viterbi(), nbins = 2) / nrow(energy)

# Grid of energy prices (x axis)
grid <- seq(min(energy$Price), max(energy$Price), length.out = 100)

# For each value of EurDol (third dimension of 'par'), compute the weighted
# normal density of each state over the price grid
pdf_ls <- lapply(seq_len(dim(par)[3]), function(i) {
  p <- par[, , i]
  pdf1 <- w[1] * dnorm(grid, mean = p["Price.mean", "state 1"], 
                       sd = p["Price.sd", "state 1"])
  pdf2 <- w[2] * dnorm(grid, mean = p["Price.mean", "state 2"], 
                       sd = p["Price.sd", "state 2"])
  # Stack the two states; the price grid is repeated once per state
  data.frame(price = rep(grid, times = 2),
             pdf = c(pdf1, pdf2), 
             state = factor(rep(1:2, each = length(grid))),
             eurdol = paste0("EurDol = ", EurDol[i]))
})
# Turn into dataframe for ggplot
pdf_df <- do.call(rbind, pdf_ls)

# Histogram of all prices (repeated in every facet) with the weighted
# state-dependent densities overlaid, one facet per value of EurDol
ggplot(pdf_df, aes(price, pdf)) + 
  facet_wrap("eurdol") +
  geom_histogram(aes(x = Price, y = after_stat(density)), bins = 20,
                 col = "white", bg = "lightgrey", data = energy) +
  geom_line(aes(col = state)) + 
  theme_bw() +
  labs(x = "Price", y = NULL) +
  scale_color_manual(values = hmmTMB:::hmmTMB_cols)

###########################################################################
## Replicate the human activity analysis presented in Appendix C  ------
## of "hmmTMB: Hidden Markov Models with Flexible Covariate Effects in R"
###########################################################################

library("hmmTMB")

##################
## Prepare data ##
##################
# Load the records of two subjects from a GitHub URL, dropping the third
# column of each file
path <- "https://raw.githubusercontent.com/huang1010/HMMS/master/"
read_subject <- function(file_name) {
  read.csv(url(paste0(path, file_name)))[, -3]
}
data1 <- read_subject("S9.csv")
data2 <- read_subject("S20.csv")

# Stack the two subjects, labelled by ID
data <- rbind(cbind(ID = "S9", data1), 
              cbind(ID = "S20", data2))
# NOTE(review): no time zone is given, so times are parsed in the session's
# local time zone -- confirm this is intended
data$time <- as.POSIXct(data$time)
# Time of day in decimal hours (hours plus minutes / 60)
hour_of_day <- as.numeric(format(data$time, "%H"))
minute_of_hour <- as.numeric(format(data$time, "%M"))
data$tod <- hour_of_day + minute_of_hour / 60

# Plot activity against time, one panel per individual
theme_set(theme_bw())
ggplot(data, aes(x = time, y = activity)) + 
  geom_line() +
  facet_wrap("ID", scales = "free_x")

###########################################
## Model specification and model fitting ##
###########################################
# Observation model: "zigamma2" distribution for activity, with initial mean,
# standard deviation and z parameter in each of the 2 states
dists <- list(activity = "zigamma2")
par0 <- list(activity = list(mean = c(20, 150), 
                             sd = c(20, 40), 
                             z = c(0.1, 0)))
obs <- Observation$new(data = data, dists = dists,
                       n_states = 2, par = par0)

# Markov chain model: transition probabilities depend on time of day through
# a smooth with basis "cc" and k = 5
tod_formula <- ~s(tod, k = 5, bs = "cc")
hid <- MarkovChain$new(data = data, n_states = 2, 
                       initial_state = "stationary",
                       formula = tod_formula)

# Fixed parameter (zero mass in state 2)
fixpar <- list(obs = setNames(NA, "activity.z.state2.(Intercept)"))

# Define HMM
hmm <- HMM$new(obs = obs, hid = hid, fixpar = fixpar)

# Prior specification: one row per coefficient, two columns. The first column
# holds the initial parameter values from above on a transformed scale (log
# for means and standard deviations, qlogis for probabilities); the second
# column is 10 throughout. The NA row lines up with the fixed parameter.
# NOTE(review): the columns are presumably the mean and standard deviation of
# a normal prior -- confirm against the hmmTMB documentation.
prior_obs <- rbind(c(log(20), 10),
                   c(log(150), 10),
                   c(log(20), 10),
                   c(log(40), 10),
                   c(qlogis(0.1), 10),
                   c(NA, NA))
prior_hid <- rbind(c(qlogis(0.05), 10),
                   c(qlogis(0.05), 10))

priors <- list(coeff_fe_obs = prior_obs, coeff_fe_hid = prior_hid)
hmm$set_priors(priors)

# Print the priors stored in the model
hmm$priors()

# Run 1 MCMC chain for 1000 iterations
n_chain <- 1
n_iter <- 1000
hmm$fit_stan(chains = n_chain, iter = n_iter, seed = 9745)

##################
## Stan outputs ##
##################
# Posterior samples of the first four parameters, reshaped to long format
# (one row per iteration and parameter)
iters_first4 <- hmm$iters()[, 1:4]
par_df <- setNames(as.data.frame.table(iters_first4),
                   c("iter", "par", "value"))

# Histogram of the posterior samples, one panel per parameter
ggplot(par_df, aes(x = value)) + 
  facet_wrap("par", scales = "free", ncol = 2) +
  geom_histogram(bins = 30, col = "grey", fill = "lightgrey") +
  labs(x = NULL)

# Summary of fit
hmm$out_stan()

########################################################
## Posterior samples of state-dependent distributions ##
########################################################
# Weights for state-specific density functions: proportion of observations
# assigned to each state by the Viterbi algorithm
w <- table(hmm$viterbi())/nrow(data)

# Select 100 random posterior samples (among the first n_iter/2 samples)
ind_post <- sort(sample(1:(n_iter/2), size = 100))

# Grid of activity values, matrix of posterior samples, and density function
# of the activity distribution; all are the same for every sample, so they are
# computed once outside the loop
activity <- seq(0.4, 200, length.out = 100)
post_iters <- hmm$iters()
activity_pdf <- obs$dists()$activity$pdf()

# For each posterior sample, compute the weighted density in each state.
# Results are collected in a list and bound once at the end, rather than
# growing a data frame inside the loop.
dens_ls <- lapply(ind_post, function(i) {
  # Columns 1-6 are used as (mean1, mean2, sd1, sd2, z1, z2)
  par <- post_iters[i, 1:6]
  dens1 <- activity_pdf(x = activity, mean = par[1], sd = par[3], z = par[5])
  dens2 <- activity_pdf(x = activity, mean = par[2], sd = par[4], z = par[6])
  dens <- data.frame(
    state = paste0("state ", rep(1:2, each = length(activity))),
    dens = c(w[1] * dens1, w[2] * dens2)
  )
  # One line per posterior sample and state
  dens$group <- paste0("iter ", i, " - ", dens$state)
  dens$activity <- rep(activity, times = 2)
  dens
})
dens_df <- do.call(rbind, dens_ls)

# Plot histogram of activity and density lines. The legend keys are lines, so
# their width is overridden with 'linewidth' (not 'size').
ggplot(dens_df, aes(activity, dens)) +
  geom_histogram(aes(y = after_stat(density)), data = data, fill = "lightgrey", 
                 col = "grey", breaks = seq(0, 200, by = 10)) +
  geom_line(aes(col = state, group = group), linewidth = 0.1, alpha = 0.5) +
  scale_color_manual(values = hmmTMB:::hmmTMB_cols, name = NULL) +
  guides(color = guide_legend(override.aes = list(linewidth = 0.5, alpha = 1)))

#################################################
## Posterior samples of stationary state probs ##
#################################################
# Select random posterior draws (among the first n_iter/2 samples)
ind_post <- sort(sample(1:(n_iter/2), size = 100))

# For each posterior sample, compute stationary state probabilities over
# grid of time of day. Results are collected in a list and bound once at the
# end, rather than growing a data frame inside the loop.
newdata <- data.frame(tod = seq(0, 24, length.out = 100))
probs_ls <- lapply(ind_post, function(i) {
  # Set the model parameters to posterior sample i before predicting
  hmm$update_par(iter = i)
  delta <- hmm$predict(what = "delta", newdata = newdata)
  probs <- data.frame(state = rep(paste0("state ", 1:2), 
                                  each = nrow(newdata)),
                      prob = as.vector(delta))
  # One line per posterior sample and state
  probs$group <- paste0("iter ", i, " - ", probs$state)
  probs$tod <- rep(newdata$tod, times = 2)
  probs
})
probs_df <- do.call(rbind, probs_ls)

# Plot stationary state probs against time of day. The legend keys are lines,
# so their width is overridden with 'linewidth' (not 'size').
ggplot(probs_df, aes(tod, prob, group = group, col = state)) + 
  geom_line(linewidth = 0.1, alpha = 0.5) +
  scale_x_continuous(breaks = seq(0, 24, by = 4)) +
  labs(x = "time of day", y = "stationary state probabilities", col = NULL) +
  scale_color_manual(values = hmmTMB:::hmmTMB_cols) +
  guides(color = guide_legend(override.aes = list(linewidth = 0.5, alpha = 1)))

###########################################################################
## Replicate the simulation study presented in Appendix D.1 ------
## of "hmmTMB: Hidden Markov Models with Flexible Covariate Effects in R"
###########################################################################


library("hmmTMB")
theme_set(theme_bw())
library(pbmcapply)
# Number of cores for the parallel loop: leave two cores free, but always use
# at least one (detectCores() may return NA, or a value of 2 or less)
n_cores <- max(1L, parallel::detectCores() - 2L, na.rm = TRUE)

# Set seed and use RNG supported by 'parallel'
RNGkind("L'Ecuyer-CMRG")
set.seed(3482)

# True relationship between covariate x1 and transition probabilities
#
# Argument:
#   x1   numeric vector of covariate values
#
# Returns a 2 x 2 x length(x1) array; slice [,,i] is the transition
# probability matrix for x1[i]. On the logit scale, Pr(1 -> 2) is quadratic
# in x1 and Pr(2 -> 1) is sinusoidal; the diagonal entries are set so that
# each row sums to 1.
my_tpm <- function(x1) {
  p12 <- plogis(-3 + 3 * x1^2)
  p21 <- plogis(-2 + sin(2*pi*x1/2))

  out <- array(NA, dim = c(2, 2, length(x1)))
  out[1, 1, ] <- 1 - p12
  out[1, 2, ] <- p12
  out[2, 1, ] <- p21
  out[2, 2, ] <- 1 - p21
  return(out)
}

# Grid of x1 values for plots, and true transition probabilities on that grid
newdata <- data.frame(x1 = seq(-1, 1, length.out = 100))
tpms_true <- my_tpm(newdata$x1)

# Simulation loop: each iteration simulates a covariate, a state sequence and
# observations, fits the model, and returns the estimated transition
# probabilities over the grid
n_iter <- 200
tpm_list <- pbmclapply(1:n_iter, function(iter) {
  ## Simulate covariate: Gaussian random walk started at 0 and reflected at
  ## the boundaries so that it stays in [-1, 1]
  n <- 5000
  x1 <- rep(NA_real_, n)
  x1[1] <- 0
  for (i in 2:n) {
    proposal <- x1[i - 1] + rnorm(1, 0, 0.1)
    x1[i] <- if (proposal < -1) {
      -1 + abs(proposal + 1)
    } else if (proposal > 1) {
      1 - abs(proposal - 1)
    } else {
      proposal
    }
  }
  # True transition probabilities along the simulated covariate
  tpm <- my_tpm(x1)
  
  # State process, started in state 1; the state at time i is drawn from the
  # row of tpm[,,i] given by the previous state
  s <- rep(NA, n)
  s[1] <- 1
  for (i in 2:n) {
    prev_state <- s[i - 1]
    s[i] <- sample(1:2, size = 1, prob = tpm[prev_state, , i])
  }
  
  # Observation process: normal with state-specific mean and unit variance
  mu <- c(-5, 5)
  z <- rnorm(n = n, mean = mu[s], sd = 1)
  
  ###############
  ## Fit model ##
  ###############
  data <- data.frame(z = z, x1 = x1)
  # State model: basis "cs" for Pr(1 -> 2) and basis "cc" for Pr(2 -> 1)
  smooth_12 <- "~s(x1, k = 10, bs = 'cs')"
  smooth_21 <- "~s(x1, k = 10, bs = 'cc')"
  tpm_formulas <- rbind(c(".", smooth_12),
                        c(smooth_21, "."))
  hid <- MarkovChain$new(data = data, formula = tpm_formulas, n_states = 2)
  # Observation model, initialised at the true parameter values
  dists <- list(z = "norm")
  par0 <- list(z = list(mean = c(-5, 5), sd = c(1, 1)))
  obs <- Observation$new(data = data, dists = dists, 
                         n_states = 2, par = par0)
  
  hmm <- HMM$new(obs = obs, hid = hid)
  hmm$fit(silent = TRUE)
  tpm_hat <- hmm$predict(what = "tpm", newdata = newdata)
  return(tpm_hat)
}, mc.cores = n_cores)

# Unpack estimates in long data frame: for every iteration, the estimated
# Pr(1 -> 2) and Pr(2 -> 1) over the covariate grid
n_grid <- nrow(newdata)
prob_labels <- rep(c("1-2", "2-1"), each = n_grid)
est_ls <- lapply(seq_along(tpm_list), function(i) {
  tpm_hat <- tpm_list[[i]]
  data.frame(iter = i, 
             x1 = rep(newdata$x1, times = 2), 
             val = c(tpm_hat[1, 2, ], tpm_hat[2, 1, ]),
             prob = prob_labels)
})
df_est <- do.call(rbind, est_ls)
# Data frame of true parameter values, in the same layout
df_true <- data.frame(x1 = rep(newdata$x1, times = 2), 
                      val = c(tpms_true[1, 2, ], tpms_true[2, 1, ]),
                      prob = prob_labels)

# Plot the estimates of one transition probability (thin black lines, one per
# iteration) with the true function on top (thick red line)
plot_tp_estimates <- function(which_prob, y_label) {
  est <- df_est[df_est$prob == which_prob, ]
  truth <- df_true[df_true$prob == which_prob, ]
  ggplot(est, aes(x1, val)) + 
    geom_line(aes(group = iter), linewidth = 0.1, alpha = 0.5) + 
    geom_line(data = truth, col = 2, linewidth = 1) +
    coord_cartesian(ylim = c(0, 1)) +
    labs(x = expression(x[1]), y = y_label)
}

# Plot Pr(1 -> 2)
plot_tp_estimates("1-2", "Pr(1 -> 2)")

# Plot Pr(2 -> 1)
plot_tp_estimates("2-1", "Pr(2 -> 1)")


###########################################################################
## Replicate the simulation study presented in Appendix D.2 ----
## of "hmmTMB: Hidden Markov Models with Flexible Covariate Effects in R"
###########################################################################

library("hmmTMB")
theme_set(theme_bw())
library(pbmcapply)
# Number of cores for the parallel loop: leave two cores free, but always use
# at least one (detectCores() may return NA, or a value of 2 or less)
n_cores <- max(1L, parallel::detectCores() - 2L, na.rm = TRUE)

# Set seed and use RNG supported by 'parallel'
RNGkind("L'Ecuyer-CMRG")
set.seed(1411)

# True relationship between covariate x1 and observation parameters
#
# Argument:
#   x1   numeric vector of covariate values
#
# Returns a matrix with one row per element of x1 and columns "rate1" and
# "rate2". On the log scale, rate1 is 1 + sin(2 * pi * x1) for x1 in
# [-0.5, 0.5] and 1 outside that interval; rate2 is constant and equal to 10.
my_obspar <- function(x1) {
  n <- length(x1)
  outside <- x1 < -0.5 | x1 > 0.5
  rate1 <- exp(1 + ifelse(outside, 0, sin(2*pi*x1)))
  rate2 <- rep(10, n)
  out <- matrix(c(rate1, rate2), nrow = n, ncol = 2,
                dimnames = list(NULL, c("rate1", "rate2")))
  return(out)
}

# Grid of covariate for plots, and true observation parameters on that grid
newdata <- data.frame(x1 = seq(-1, 1, length.out = 100))
par_true <- my_obspar(newdata$x1)

# Simulation loop: each iteration simulates a covariate, a state sequence and
# Poisson observations, fits the model, and returns the estimated rate in
# state 1 over the grid
n_iter <- 200
mu_list <- pbmclapply(1:n_iter, function(iter) {
  # Simulate covariate: Gaussian random walk started at 0 and reflected at
  # the boundaries so that it stays in [-1, 1]
  n <- 2000
  x1 <- rep(NA_real_, n)
  x1[1] <- 0
  for (i in 2:n) {
    proposal <- x1[i - 1] + rnorm(1, 0, 0.1)
    x1[i] <- if (proposal < -1) {
      -1 + abs(proposal + 1)
    } else if (proposal > 1) {
      1 - abs(proposal - 1)
    } else {
      proposal
    }
  }
  # True observation parameters along the simulated covariate
  obspar <- my_obspar(x1)
  
  # State process, started in state 1, with a constant transition matrix
  tpm <- matrix(c(0.9, 0.1, 0.1, 0.9), nrow = 2)
  s <- rep(NA, n)
  s[1] <- 1
  for (i in 2:n) {
    prev_state <- s[i - 1]
    s[i] <- sample(1:2, size = 1, prob = tpm[prev_state, ])
  }
  
  # Observation process: for each time point, pick the rate of the current
  # state (row = time, column = state), then draw a Poisson count
  rate <- obspar[cbind(seq_len(n), s)]
  z <- rpois(n = n, lambda = rate)
  
  ###############
  ## Fit model ##
  ###############
  data <- data.frame(z = z, x1 = x1)
  # State model
  hid <- MarkovChain$new(data = data, n_states = 2)
  # Observation model: the smooth of x1 is wrapped in state1()
  dists <- list(z = "pois")
  par0 <- list(z = list(rate = c(3, 10)))
  f <- list(z = list(rate = ~ state1(s(x1, k = 15, bs = "cs"))))
  obs <- Observation$new(data = data, dists = dists, formulas = f,
                         n_states = 2, par = par0)
  
  hmm <- HMM$new(obs = obs, hid = hid)
  hmm$fit(silent = TRUE)
  obspar_hat <- hmm$predict(what = "obspar", newdata = newdata)
  return(obspar_hat["z.rate", "state 1", ])
}, mc.cores = n_cores)

# Unpack estimates into a long data frame: the covariate grid is repeated
# once per iteration, alongside the estimated rate and the iteration index
n_grid <- nrow(newdata)
df_est <- data.frame(x1 = rep(newdata$x1, times = n_iter),
                     val = unlist(mu_list),
                     iter = rep(1:n_iter, each = n_grid))

# Data frame of true parameter values (rate in state 1)
df_true <- data.frame(x1 = newdata$x1, val = par_true[, 1])

# Plot truth (thick red line) and estimates (thin black lines, one per
# iteration)
ggplot(df_est, aes(x = x1, y = val, group = iter)) +
  geom_line(linewidth = 0.1, alpha = 0.5) +
  geom_line(aes(group = NA), data = df_true, col = 2, linewidth = 1) +
  labs(x = expression(x[1]), y = expression(mu[1]))

###########################################################################
## Replicate the simulation study presented in Appendix D.3 -----
## of "hmmTMB: Hidden Markov Models with Flexible Covariate Effects in R"
###########################################################################

library("hmmTMB")
theme_set(theme_bw())
library(pbmcapply)
# Number of cores for the parallel loop: all detected cores, falling back to
# one if detectCores() returns NA
n_cores <- max(1L, parallel::detectCores(), na.rm = TRUE)

# Set seed and use RNG supported by 'parallel'
RNGkind("L'Ecuyer-CMRG")
set.seed(8145)

# Function to get transition probability matrix with random intercept
# from vector of time series IDs
#
# Argument:
#   ID   vector of time series identifiers, one per observation
#
# Returns a 2 x 2 x length(ID) array; slice [,,i] is the transition
# probability matrix for observation i. On the logit scale, each transition
# probability has intercept -2.5 plus a normal random intercept shared by all
# observations of the same ID (sd = 1 for Pr(1 -> 2), sd = 0.5 for
# Pr(2 -> 1)).
rand_tpm <- function(ID) {
  n <- length(ID)
  ids <- unique(ID)
  n_ID <- length(ids)
  # Index of the time series that each observation belongs to. Matching on
  # the ID value (rather than repeating by group size) gives every
  # observation the intercept of its own ID even if the IDs are not sorted or
  # not stored contiguously.
  group <- match(ID, ids)

  # Random intercepts, one per ID (drawn for Pr(1 -> 2) first)
  re12 <- rnorm(n_ID, mean = 0, sd = 1)
  re21 <- rnorm(n_ID, mean = 0, sd = 0.5)

  linpred <- matrix(NA, nrow = n, ncol = 2)
  linpred[,1] <- -2.5 + re12[group]
  linpred[,2] <- -2.5 + re21[group]

  tpm <- array(NA, dim = c(2, 2, n))
  tpm[1,2,] <- plogis(linpred[,1])
  tpm[2,1,] <- plogis(linpred[,2])
  # Diagonal entries are set so that each row sums to 1
  tpm[1,1,] <- 1 - tpm[1,2,]
  tpm[2,2,] <- 1 - tpm[2,1,]
  return(tpm)
}

# Simulation design
n_ID <- 20                         # number of individuals (random effect group)
n_by_ID <- 500                     # number of observations per individual
n <- n_ID * n_by_ID                # total number of observations
ID <- rep(1:n_ID, each = n_by_ID)  # vector of IDs

# Loop over iterations: each iteration simulates state sequences and
# observations for all individuals, fits the model, and returns the estimated
# standard deviations of the random effects
n_iter <- 200
sd_list <- pbmclapply(1:n_iter, function(iter) {
  # State process: each individual starts in state 1, and the state at each
  # later row is drawn from the row of that row's transition matrix given by
  # the state at the previous row
  tpm <- rand_tpm(ID)
  s <- rep(NA, n)
  for (id in 1:n_ID) {
    rows <- which(ID == id)
    s[rows[1]] <- 1
    for (k in 2:n_by_ID) {
      now <- rows[k]
      s[now] <- sample(1:2, size = 1, prob = tpm[s[now - 1], , now])
    }
  }
  
  # Observation process: gamma distribution with state-specific mean and
  # standard deviation, converted to shape and rate
  mean_true <- c(3, 15)
  sd_true <- c(2, 5)
  shape <- mean_true[s]^2/sd_true[s]^2
  rate <- mean_true[s]/sd_true[s]^2
  z <- rgamma(n = n, shape = shape, rate = rate)
  
  ###############
  ## Fit model ##
  ###############
  data <- data.frame(ID = factor(ID), z = z)
  # State model with random effect of ID
  f <- ~ s(ID, bs = "re")
  hid <- MarkovChain$new(data = data, formula = f, n_states = 2)
  # Observation model, initialised at the true parameter values
  dists <- list(z = "gamma2")
  par0 <- list(z = list(mean = c(3, 15), sd = c(2, 5)))
  obs <- Observation$new(data = data, dists = dists,
                         n_states = 2, par = par0)
  
  hmm <- HMM$new(obs = obs, hid = hid)
  hmm$fit(silent = TRUE)
  # Return estimated SD of random effects (hidden state model component)
  sd_hat <- hmm$sd_re()$hid
  return(sd_hat)
}, mc.cores = n_cores)

# Unpack estimated standard deviations in data frame: bind the per-iteration
# results as columns, then transpose so that each row is one iteration
sd_mat <- do.call(cbind, sd_list)
sd_df <- as.data.frame(t(sd_mat))

# Overlaid histograms of the estimated standard deviations for the two
# transition probabilities, with dashed vertical lines at the true values
# (1 and 0.5)
col_12 <- "firebrick"
col_21 <- "royalblue"
ggplot(sd_df, aes(x = `S1>S2.s(ID)`)) + 
  geom_histogram(fill = adjustcolor(col_12, 0.5), bins = 25) + 
  geom_histogram(aes(x = `S2>S1.s(ID)`), 
                 fill = adjustcolor(col_21, 0.5), bins = 25) +
  geom_vline(xintercept = 1, lty = 2, linewidth = 1, col = col_12) +
  geom_vline(xintercept = 0.5, lty = 2, linewidth = 1, col = col_21) +
  annotate("text", x = 1.2, y = 20, col = col_12, label = "Pr(1 -> 2)") +
  annotate("text", x = 0.3, y = 20, col = col_21, label = "Pr(2 -> 1)") +
  labs(x = "random effect standard deviation")


# Record R and package versions used to run the script
sessionInfo()
