##########################################################################################
## Replication script for:                                                              ##
## "Efficient Multiple Imputation for Diverse Data in Python and R: MIDASpy and rMIDAS" ##
##########################################################################################

## See README for more detailed guidance on replication
## Due to the time and memory required to run all tests presented in this paper,
## we run a **substantially** reduced version of the Python imputation workflow
## and then replicate the remaining figures and output

#########################
## Imputation workflow ##
#########################
## Load R dependencies (install any that are not already installed)
library("ggplot2")
library("data.table")
library("mltools")
library("reshape2")
library("reticulate")

## Install and load rMIDAS
# install.packages("rMIDAS")
library("rMIDAS")

## IMPORTANT: please follow these steps in a command-line terminal (not the R console):
##   1. Navigate to the replication folder from the command line
##   2. Make sure conda is installed (on M1/M2 chipset Macs, we recommend using miniforge)
##   3. Set up a new conda environment by running the following:
##     * M1 or M2 Mac: `conda env create -f Data/midas-env-arm64.yml`
##     * All other systems (including Intel-based Macs): `conda env create -f Data/midas-env.yml`

## Now that the "rmidas" conda environment has been created, proceed in R:

set_python_env("rmidas", type = "conda")
## Make sure to set the main replication folder as the working directory

############################
## Set up for replication ##
############################
## Every data path in this script (e.g. "Data/cces_jss_format.csv") is
## relative to the replication folder, so stop early with a clear message if
## it is not the working directory, rather than failing later mid-run.
if (!file.exists(file.path("Data", "cces_jss_format.csv"))) {
  stop("Data/cces_jss_format.csv not found: please set the main replication ",
       "folder as the working directory.", call. = FALSE)
}

## Prevent GPU computation if available
## (presumably this must be set before TensorFlow is first imported, which
## happens in Section 5.1 -- keep this call ahead of it)
reticulate::py_run_string("import os; os.environ['CUDA_VISIBLE_DEVICES'] = '-1'")

## Output folder for saved figures; a no-op if it already exists
dir.create("Figures", showWarnings = FALSE)

##########################################
## Section 5.1 -- MIDASpy demonstration ##
##########################################

## Record the start time of the whole replication run (used in the timing
## summary at the end of the script)
t_0 <- Sys.time()

## MIDASpy demonstration, run entirely in Python through reticulate. The
## embedded Python script:
##   1. reads Data/cces_jss_format.csv and picks out continuous (cont_vars),
##      binary (bin_vars) and categorical (cat_vars) variables;
##   2. converts the binary columns with md.binary_conv and one-hot encodes
##      the categorical columns with md.cat_conv (cat_col_list is passed to
##      build_model as softmax_columns);
##   3. min-max scales the combined data, keeping `scaler` to undo it later;
##   4. builds a [256, 256] MIDAS network (vae_layer = False, seed = 89,
##      input_drop = 0.75), trains it for 10 epochs and draws m = 10
##      completed datasets;
##   5. for each completed dataset, un-scales it, derives age = 2018 - birthyr,
##      sets CC18_415a to 1 where its un-scaled value is >= 0.5 (else 0), and
##      combines the per-dataset models of CC18_415a on age with md.combine.
## Python objects created here (e.g. `model`) remain available to the later
## py_run_string() calls below. `sys`, `tf` and `vals` are not referenced
## again within this snippet.
py_run_string(
'import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.preprocessing import MinMaxScaler
import sys
import MIDASpy as md

data_in = pd.read_csv("Data/cces_jss_format.csv")

cont_vars = ["citylength_1", "numchildren", "birthyr"]
vals = data_in.nunique()

cat_vars = ["CC18_401", "CC18_413d", "educ", "race", "marstat", "votereg",
              "OScode", "CompRating", "region", "pid7", "immstat", "employ",
              "sexuality", "trans", "industryclass", "pew_churatd"]

bin_vars = ["CC18_415a", "CC18_417_a", "gender"]

data_bin = data_in[bin_vars].apply(md.binary_conv)

constructor_list = [data_in[cont_vars], data_bin]

data_cat = data_in[cat_vars]

data_oh, cat_col_list = md.cat_conv(data_cat)

constructor_list.append(data_oh)

data_0 = pd.concat(constructor_list, axis = 1)

scaler = MinMaxScaler()

data_scaled = scaler.fit_transform(data_0)
data_scaled = pd.DataFrame(data_scaled, columns = data_0.columns)

na_loc = data_scaled.isnull()
data_scaled[na_loc] = np.nan

imputer = md.Midas(layer_structure = [256, 256],
                       vae_layer = False,
                       seed = 89,
                       input_drop = 0.75)

imputer.build_model(data_scaled,
                    binary_columns = bin_vars,
                    softmax_columns =  cat_col_list)

imputer.train_model(training_epochs = 10)

imputations = imputer.yield_samples(m = 10)

analysis_dfs = []

for df in imputations:
  df_unscaled = scaler.inverse_transform(df)
  df_unscaled = pd.DataFrame(df_unscaled, columns = data_scaled.columns)
  df["age"] = 2018 - df_unscaled["birthyr"]
  df["CC18_415a"] = np.where(df_unscaled["CC18_415a"] >= 0.5, 1, 0)
  analysis_dfs.append(df.loc[:,["age", "CC18_415a"]])

model = md.combine(y_var = "CC18_415a",
                   X_vars = ["age"],
                   df_list = analysis_dfs)
model
'
)

## Print the combined model results computed in the Python session above
py_run_string("print(model)")

## End of the Python (MIDASpy) demonstration
t_python <- Sys.time()

## Delete the large Python objects and run Python's garbage collector so the
## memory is released before the R (rMIDAS) demonstration
py_run_string("import gc; del(analysis_dfs, imputations, imputer, data_scaled, data_oh, data_cat, data_bin, data_in); gc.collect()")

#########################################
## Section 5.2 -- rMIDAS demonstration ##
#########################################

set.seed(89)

## Read data for analysis: the same categorical, binary and continuous
## variables used in the MIDASpy demonstration
cat_vars <- c(
  "CC18_401", "CC18_413d", "educ", "race", "marstat", "votereg", "OScode",
  "CompRating", "region", "pid7", "immstat", "employ", "sexuality", "trans",
  "industryclass", "pew_churatd"
)
bin_vars <- c("CC18_415a", "CC18_417_a", "gender")
cont_vars <- c("citylength_1", "numchildren", "birthyr")

data_0 <- fread(
  "Data/cces_jss_format.csv",
  select = c(cat_vars, bin_vars, cont_vars)
)
t_readin <- Sys.time()

#### A. Basic imputation demonstration ####
## Step 1. Pre-process: declare the binary and categorical columns and
## min-max scale the data
data_conv <- convert(
  data_0,
  bin_cols = bin_vars,
  cat_cols = cat_vars,
  minmax_scale = TRUE
)
t_convert <- Sys.time()

## Step 2. Build and train the imputation model
data_train <- train(
  data_conv,
  layer_structure = c(256, 256),
  vae_layer = FALSE,
  seed = 89,
  input_drop = 0.75,
  training_epochs = 10
)
t_train <- Sys.time()

## Step 3. Draw M = 10 completed datasets
imputations <- complete(data_train, m = 10)
t_complete <- Sys.time()

## Demonstrate imputation results
analysis_vars <- c("CC18_401", "CC18_413d", "CC18_415a", "CC18_417_a",
                   "citylength_1", "numchildren")

## Number of completed datasets (M) returned by complete()
n_imp <- length(imputations)

## Stack the M completed data.frames (analysis variables only) into one long
## data.frame to make plotting easier. Binding once via do.call() avoids
## re-copying a growing object on every loop iteration, and using n_imp
## rather than a hard-coded 10 keeps this in sync with complete(m = ...).
imps_stack <- do.call(rbind, unname(lapply(imputations,
                                           function(x) x[, analysis_vars])))

## Label each stacked row with the index of the completed dataset it came from
imps_stack$imputation <- factor(rep(seq_len(n_imp),
                                    each = nrow(imputations[[1]])))

## Plot example variables

## Isolate missing values: rows where citylength_1 was originally missing,
## repeated once per completed dataset to line up with imps_stack
var_missing <- rep(is.na(data_0$citylength_1), times = n_imp)

ggplot(imps_stack[var_missing,], aes(x = citylength_1, fill = imputation)) +
  geom_density(alpha = 0.4) +
  labs(x = "Prediction for citylength_1", y = "Density") +
  theme_minimal() +
  guides(fill = guide_legend(title = "Set of imputations (M)",
                             title.position = "top", nrow = 1,
                             title.hjust = 0.5)) +
  theme(legend.position = "bottom")

## Generate average imputed values (i.e. cell means across the M completed
## data.frames). Only the analysis variables are kept, since nothing else is
## plotted. Columns are coerced with as.numeric() so they can be summed.
## NOTE(review): this assumes complete() returns categorical columns as
## numeric-like codes; a factor column would yield level indices -- confirm.
imps <- lapply(imputations, function(x)
  as.data.frame(lapply(x[, analysis_vars], as.numeric)))

imps_avg <- round(Reduce(`+`, imps) / length(imps))
imps_avg$type <- paste0("Average MIDAS imputation (M=", length(imps), ")")

## Original (incomplete) values of the same variables, for comparison
incomplete <- data_0[,..analysis_vars]
incomplete$type <- "Original value"

## Long format: one row per (type, variable, value)
imps_comparison <- reshape2::melt(rbind(imps_avg, incomplete), id.vars = "type")

ggplot(imps_comparison, aes(x = value, fill = type)) +
  facet_wrap(~ variable, scales = "free") +
  scale_fill_manual(values = c("blue", "red")) +
  geom_density(alpha = 0.5) +
  labs(x = "Value", y = "Density", fill = "") +
  theme_minimal() +
  theme(legend.position = "bottom")

## Regression example: in every completed dataset, derive age from birth year
## (survey year 2018) and recode CC18_415a to 1 where it equals 1, else 0.
## seq_along() keeps the loop in sync with however many datasets exist.
for (d in seq_along(imputations)) {
  imputations[[d]]$age <- 2018 - imputations[[d]]$birthyr
  imputations[[d]]$CC18_415a <- ifelse(imputations[[d]]$CC18_415a == 1, 1, 0)
}

## Fit the model on each completed dataset and combine the results
combine("CC18_415a ~ age", imputations)

## Remove the completed datasets and other large objects to free RAM
rm(imputations, imps_stack, imps_comparison, imps, data_train)
gc(verbose = FALSE)

t_analyse <- Sys.time()

###################################
## Section 6.1 -- overimputation ##
###################################
## To replicate this graph using the full data please see Code/full_code.R
## Figures 5, 6 and 7 will be saved in Figures/
overimpute(data_conv,
           layer_structure = c(256, 256),
           vae_layer = FALSE,
           seed = 89,
           spikein = 0.3,
           training_epochs = 100,
           report_ival = 25,
           plot_vars = TRUE,
           spike_seed = 89,
           save_path = "Figures/")

## Tidy up output files: among the per-epoch figures ("_epoch_" in the file
## name) written to Figures/, keep only the epoch-100 plots for industryclass
## and numchildren and delete the rest. file.path() yields no paths (rather
## than the bare "Figures/" that paste0() would give) when nothing matches.
rm_figs <- list.files("Figures", pattern = "_epoch_")
rm_figs <- rm_figs[!grepl("industryclass_epoch_100|numchildren_epoch_100",
                          rm_figs)]
unlink(file.path("Figures", rm_figs))

t_overimp <- Sys.time()

###################
## Section 6.2 ####
###################

## Hyperparameter test
# Given the time taken to generate 12 models and the JSS runtime requirements,
# we use saved output data to replicate the figure. The output-generating code 
# is included as a comment below

# nodes <- c(64, 128, 256, 512)
# layers <- c(2,3,4)
# node_layers <- list()
#
# for (n in nodes) {
#   for (l in layers) {
#     node_layers[[length(node_layers) + 1]] <- rep(n, l)
#   }
# }
#
# # Run different spec MIDAS models
# # note: python output captured and so won't be displayed in console
# hyper_res <- py_capture_output(
#
#   for (nl in node_layers) {
#
#     print("Node structure: ")
#     print(nl)
#
#     overimpute(data_conv,
#                layer_structure= nl,
#                vae_layer= FALSE,
#                seed= 89,
#                input_drop = 0.75,
#                spikein = 0.3,
#                training_epochs = 100,
#                report_ival = 100,
#                plot_vars = FALSE,
#                skip_plot = TRUE,
#                spike_seed = 89)
#
#   },
#
#   type = c("stdout")
# )
#
# # Convert text output to data.frame
# hyper_vec <- strsplit(hyper_res, "\n")[[1]]
#
# softmax_agg = c()
#
# for (line in hyper_vec) {
#
#   if (grepl("Aggregated error on softmax spike-in", line)) {
#
#     softmax_agg <- append(softmax_agg, substr(line,38,nchar(line)))

#   }
# }
#
# hyperp_data <- data.frame(
#   epochs = rep(c(0,100),length(node_layers)),
#   layers = rep(layers,each=2),
#   nodes = rep(nodes, each = 2*length(layers)),
#   softmax_agg = as.numeric(softmax_agg))
#
# write.csv(hyperp_data, "Data/hyperparameter_results.csv")

## Figure replication:

hyperp_data <- read.csv("Data/hyperparameter_results.csv")

## Heatmap of the aggregated softmax (categorical) error after 100 training
## epochs for each combination of nodes per layer and number of layers
ggplot(hyperp_data[hyperp_data$epochs == 100,],
       aes(x = as.factor(nodes), y = as.factor(layers))) +
  scale_fill_gradient(low = "olivedrab2", high = "firebrick2") +
  geom_tile(aes(fill = softmax_agg)) +
  geom_text(aes(label = signif(softmax_agg, 3))) +
  theme_minimal() +
  labs(x = "Nodes per layer", y = "No. of hidden layers",
       fill = "Aggregated error on softmax (categorical)") +
  ## Centre the legend title via element_text(hjust); the former
  ## `legend.title.align` theme element is deprecated as of ggplot2 3.5.0
  theme(legend.position = "bottom",
        legend.title = element_text(hjust = 0.5)) +
  guides(fill = guide_colorbar(title.position = "top", barwidth = 20))

t_hyper <- Sys.time()

## Learning rate test

learn_rates <- c(0.00001, 0.0001, 0.001, 0.01, 0.1)

## Training epochs per learning rate; must match `training_epochs` in the
## output-generating code below (used to produce the saved results)
n_epochs <- 20

# lr_out <- py_capture_output(
#
#   {
#     for (lr in learn_rates) {
#
#       train(data_conv,
#             layer_structure = c(128, 128),
#             vae_layer = FALSE,
#             seed = 89,
#             input_drop = 0.75,
#             training_epochs = 20,
#             learn_rate = lr)
#     }
#   }
# )
#
# write.csv(lr_out, "Data/learn_rate_results.csv")

## Figure replication
lr_out <- readLines("Data/learn_rate_results.csv")

## Keep the per-epoch progress lines and strip everything up to the loss value
loss_res <- lr_out[grep("Epoch: ", lr_out)]
loss_vec <- gsub("Epoch: .* , loss: ", "", loss_res)

## Guard against a mismatched line count, which data.frame() below would
## otherwise either reject with a cryptic error or silently recycle
n_expected <- length(learn_rates) * n_epochs
if (length(loss_vec) != n_expected) {
  stop("Expected ", n_expected, " epoch/loss lines in ",
       "Data/learn_rate_results.csv but found ", length(loss_vec), ".",
       call. = FALSE)
}

## One row per (learning rate, epoch); epochs are numbered from 0
lr_results <- data.frame(
  learn_rate = factor(rep(learn_rates, each = n_epochs)),
  epoch = rep(seq_len(n_epochs) - 1L, times = length(learn_rates)),
  loss = log(as.numeric(loss_vec))
)

ggplot(lr_results, aes(x = epoch, y = loss, color = learn_rate)) +
  geom_point() +
  geom_line() +
  labs(x = "Training epoch", y = "Training loss (log)",
       color = "Learning rate") +
  theme_minimal() +
  scale_color_brewer(palette = "Set1") +
  scale_x_continuous(breaks = seq_len(n_epochs) - 1L) +
  theme(legend.position = "bottom")

t_learn <- Sys.time()

#########################################
## Section 6.3 Variational autoencoder ##
#########################################

## 4 models: two network sizes ([512, 512] and [128, 128]), each trained with
## and without a variational autoencoder (VAE) layer. For each model, keep the
## analysis variables (as numeric) from its M = 10 completed datasets.
vae_imps <- list()
vae_layer_struc <- list(c(512, 512), c(128, 128))

for (layer_struc in vae_layer_struc) {
  for (t_val in c(TRUE, FALSE)) {

    ## e.g. "mod_512_vae" (with VAE layer) or "mod_512" (regular network)
    mod_name <- paste0("mod_", layer_struc[1], if (t_val) "_vae" else "")

    print(mod_name)

    vae_mod <- train(data_conv,
                     layer_structure = layer_struc,
                     vae_layer = t_val,
                     seed = 89,
                     input_drop = 0.75,
                     training_epochs = 10)

    vae_completes <- complete(vae_mod, m = 10)

    ## Analysis variables only, coerced to numeric so the completed datasets
    ## can be averaged cell-wise afterwards
    vae_imps[[mod_name]] <- lapply(vae_completes, function(x)
      as.data.frame(lapply(x[, analysis_vars], as.numeric)))

    ## Release the trained model and full completed datasets before the next
    ## model is trained, to keep memory usage down
    rm(vae_mod, vae_completes)
    gc()
  }
}

# Average datasets: for each of the four models, take the cell-wise mean of
# its completed datasets and round it
vae_results <- lapply(vae_imps, function (x) round(Reduce(`+`, x) / length(x)))

# Get original data: the analysis variables with their missing values intact
orig_data <- data_0[, ..analysis_vars]

# Extract results for the [512, 512] VAE model ("mod_512_vae"): melt the
# averaged imputations and the original data to long format, bind them side
# by side as columns `vae` and `original` (both follow the analysis_vars
# column order, so rows line up), then melt again so that `name` records
# which of the two sources each value came from.
# NOTE(review): the result depends on how melt() dispatches for a data.table
# (orig_data) vs. a data.frame and on the `[, "original"]` extraction --
# re-check before restyling.
plot_512 <- reshape2::melt(

  cbind(
    reshape2::melt(vae_results[["mod_512_vae"]], value.name = "vae"),
    original = reshape2::melt(orig_data, value.name = "original")[, "original"]
  ),
  variable.name = "name"
)

# Human-readable legend labels ("M = 10" matches complete(..., m = 10) above)
plot_512$name <- ifelse(plot_512$name == "vae", "Average MIDAS imputation (M = 10)", "Original value")

# Density of averaged VAE imputations vs. original values, one panel per
# analysis variable
ggplot(plot_512, aes(x = value, fill = name)) +
  facet_wrap(~variable, ncol = 3, scales = "free") +
  geom_density(alpha = 0.5) +
  theme_minimal() +
  labs(x = "(Predicted) Value", y = "Density", fill = "Model") +
  theme(legend.position = "bottom")

# Comparison between regular and VAE MIDAS networks: for each network size,
# pair the averaged citylength_1 imputations of the regular and the VAE model
# row by row, alongside the original (possibly missing) value. The row count
# is taken from the data rather than hard-coded, so any dataset size works.
vae_comparison <- data.frame(
  orig = rep(orig_data$citylength_1, times = 2),
  imp = c(vae_results[["mod_512"]]$citylength_1,
          vae_results[["mod_128"]]$citylength_1),
  vae = c(vae_results[["mod_512_vae"]]$citylength_1,
          vae_results[["mod_128_vae"]]$citylength_1),
  network = rep(c("Network Structure: [512,512]",
                  "Network Structure: [128,128]"),
                each = nrow(orig_data))
)

vae_comparison$missing <- ifelse(is.na(vae_comparison$orig), "Missing", "Not missing")

ggplot(vae_comparison, aes(x = vae, y = imp, color = missing)) +
  facet_wrap(~ network, ncol = 2) +
  scale_color_manual(values = c("red3", "grey30")) +
  geom_point(alpha = 0.5) +
  geom_abline(linetype = "dashed") +
  xlim(0, 50) +
  ylim(0, 50) +
  labs(x = "Variational MIDAS imputation", y = "Regular MIDAS imputation", color = "") +
  theme_minimal() +
  theme(legend.position = "bottom")

t_vae <- Sys.time()

## -------------------------------------------------------------
#### REPLICATION RUNTIME & SYS.INFO ####

## Total runtime (start of Section 5.1 to end of Section 6.3)
print(difftime(t_vae, t_0))

## Timing breakdowns
## MIDASpy (Python) demonstration, Section 5.1
print(difftime(t_python, t_0))

## rMIDAS example imputation: convert(), train() and complete()
print(difftime(t_complete, t_readin))

## Analysis of the completed data, plus removing it to ease memory usage
print(difftime(t_analyse, t_complete))

## Overimputation, Section 6.1
print(difftime(t_overimp, t_analyse))

## Hyperparameter figure (from saved results)
print(difftime(t_hyper, t_overimp))

## Learning-rate figure (from saved results)
print(difftime(t_learn, t_hyper))

## VAE comparison, Section 6.3
print(difftime(t_vae, t_learn))

## System and R session details
Sys.info()
sessionInfo()
