library(markets)
library(filelock)
library(microbenchmark)

cargs <- commandArgs()

#' Read a command-line option
#'
#' Look up an option passed to the script as `--<name>=<value>`.
#'
#' @param name Option name without the leading dashes and trailing `=`.
#' @param default Value returned when the option is absent. When `NULL`, the
#'   option is required and its absence stops the script with an explicit error.
#' @return The option value as a string (the last occurrence wins), or `default`.
command_line_option <- function(name, default = NULL) {
    prefix <- paste0("--", name, "=")
    values <- cargs[startsWith(cargs, prefix)]
    if (length(values) == 0) {
        if (is.null(default)) {
            stop("Missing required command-line option ", prefix, "<value>", call. = FALSE)
        }
        return(default)
    }
    substring(values[length(values)], nchar(prefix) + 1)
}

# Model to be benchmarked
model_string <- command_line_option("model")
# Exponent of the number of observations
nexp <- as.numeric(command_line_option("nexp"))
# Number of entities
nobs <- 2 * 2^nexp + 10
# Number of time-points
tobs <- 5
# Number of repetitions used
reps <- as.numeric(command_line_option("reps"))
# Number of extra simulated regressors per equation (none unless --addpars is given)
addpars <- as.numeric(command_line_option("addpars", default = "0"))
# Total number of estimated parameters: 6 demand and 5 supply coefficients, two
# variances and one correlation (14), plus one extra coefficient per extra regressor
# in each of the two equations.
params <- 14 + 2 * addpars
if (model_string == "diseq_directional") {
    # The directional model's supply equation is specified without a price term.
    params <- params - 1
} else if (model_string == "diseq_deterministic_adjustment") {
    # Price adjustment coefficient.
    params <- params + 1
} else if (model_string == "diseq_stochastic_adjustment") {
    # Price adjustment, price-equation constant and regressor, price variance, and
    # the two price-shock correlations.
    params <- params + 6
}
# Output filename
filename <- command_line_option("out")
# Global evaluation counter (used for benchmark bookkeeping)
evaluation <- 0
# Global model object
model <- NULL
# Commonly used optimization controls
optimization_controls <- list(maxit = 1e+5)


#' Make Message
#'
#' Build a diagnostic line: a summary of the benchmark configuration (model name,
#' total observations `nobs * tobs`, parameter count, repetitions) followed by the
#' caller's message pieces and a trailing newline.
#'
#' @param ... Message data, pasted together without separators (named arguments
#'   such as `collapse` are forwarded to `paste0`).
#' @return A diagnostic message string.
make_message <- function(...) {
    configuration <- list(
        model_string, " with ", nobs * tobs, " obs, ",
        params, " params, and ", reps, " reps: "
    )
    do.call(paste0, c(configuration, list(...), list("\n")))
}

#' Print Message
#'
#' Write a diagnostic message, prefixed with the benchmark configuration by
#' `make_message`, to standard output.
#'
#' @param ... Message data forwarded to `make_message`.
#' @return No return, called for side effects (printing).
print_message <- function(...) {
    text <- make_message(...)
    cat(text)
}

#' Randomized coefficient
#'
#' Draw a single model parameter from a normal distribution with the given mean.
#' The standard deviation defaults to 0.1.
#'
#' @param mean Distribution mean.
#' @param sd Distribution standard deviation.
#' @return A randomly drawn coefficient value.
noisy_coefficient <- function(mean, sd = 0.1) {
    draw <- stats::rnorm(n = 1, mean = mean, sd = sd)
    draw
}

#' Prepare simulation parameters
#'
#' Build the named list of parameters used to simulate benchmark data. Every entry
#' starts at a fixed baseline; the entries relevant to the requested model are then
#' replaced by draws from `noisy_coefficient` around model-specific means, in a fixed
#' order. When the global `addpars` is positive, `addpars` extra demand coefficients
#' (`D_Xd3`, `D_Xd4`, ...) and `addpars` extra supply coefficients (`S_Xs2`,
#' `S_Xs3`, ...) are appended, each drawn around zero.
#'
#' @param model Model type string.
#' @return A named list of coefficient values to be used in the simulation.
prepare_simulation_parameters <- function(model) {
    # Baseline values; parameters a model does not randomize keep these.
    sim <- list(
        D_P = 0.0, D_CONST = 0.0, D_Xd1 = 0.0, D_Xd2 = 0.0,
        D_X1 = 0.0, D_X2 = 0.0,
        S_P = 0.0, S_CONST = 0.0, S_Xs1 = 0.0,
        S_X1 = 0.0, S_X2 = 0.0,
        P_DIFF = 1.0, P_CONST = 0.0, P_Xp1 = 0.0,
        D_VARIANCE = 1.0, S_VARIANCE = 1.0, P_VARIANCE = 1.0,
        RHO = 0.0, RHO_DP = 0.0, RHO_SP = 0.0
    )

    # Means of the randomized parameters, listed in draw order (the order fixes which
    # random number each parameter receives under a given seed).
    means <- switch(model,
        diseq_basic = c(
            D_P = -0.5, D_CONST = 10.9, D_Xd1 = 0.3, D_Xd2 = -0.2,
            D_X1 = -0.3, D_X2 = -0.1,
            S_P = 0.5, S_CONST = 9.8, S_Xs1 = 0.1, S_X1 = 0.3, S_X2 = 0.2,
            D_VARIANCE = 0.9, S_VARIANCE = 1.2, RHO = 0.0
        ),
        diseq_directional = c(
            D_P = -1.7, D_CONST = 24.1, D_Xd1 = 2.3, D_Xd2 = 1.2,
            D_X1 = 1.3, D_X2 = 1.1,
            # S_P is not drawn for this model; it keeps its baseline of 0.0.
            S_CONST = 23.0, S_Xs1 = -1.3, S_X1 = -1.5, S_X2 = 1.2,
            D_VARIANCE = 1.0, S_VARIANCE = 1.2, RHO = 0.0
        ),
        diseq_deterministic_adjustment = c(
            D_P = -2.7, D_CONST = 39.9, D_Xd1 = 2.1, D_Xd2 = -0.7,
            D_X1 = 3.5, D_X2 = 6.25,
            S_P = 2.8, S_CONST = 33.2, S_Xs1 = 0.65, S_X1 = 0.15, S_X2 = 4.2,
            P_DIFF = 1.2,
            D_VARIANCE = 1.0, S_VARIANCE = 1.0, RHO = 0.0
        ),
        diseq_stochastic_adjustment = c(
            D_P = -0.3, D_CONST = 8.9, D_Xd1 = -0.1, D_Xd2 = -0.7,
            D_X1 = 0.5, D_X2 = 0.25,
            S_P = 0.4, S_CONST = 5.2, S_Xs1 = 0.65, S_X1 = -0.15, S_X2 = -0.2,
            P_DIFF = 0.3, P_CONST = 0.2, P_Xp1 = 0.65,
            D_VARIANCE = 1.0, S_VARIANCE = 1.0, P_VARIANCE = 1.0,
            RHO = 0.0, RHO_DP = 0.0, RHO_SP = 0.0
        ),
        equilibrium_model = c(
            D_P = -1.9, D_CONST = 28.9, D_Xd1 = -4.1, D_Xd2 = 1.7,
            D_X1 = -3.5, D_X2 = 3.25,
            S_P = 5.1, S_CONST = 18.2, S_Xs1 = 4.6, S_X1 = 3.1, S_X2 = 2.2,
            D_VARIANCE = 1.0, S_VARIANCE = 1.0, RHO = 0.0
        ),
        stop("Unhandled model type.")
    )

    for (name in names(means)) {
        sim[[name]] <- noisy_coefficient(means[[name]])
    }

    if (addpars > 0) {
        demand_ids <- 3:(2 + addpars)
        supply_ids <- 2:(1 + addpars)
        # All extra demand draws happen before the extra supply draws.
        extra_demand <- vapply(demand_ids, function(i) noisy_coefficient(0.0), numeric(1))
        names(extra_demand) <- sprintf("D_Xd%d", demand_ids)
        extra_supply <- vapply(supply_ids, function(i) noisy_coefficient(0.0), numeric(1))
        names(extra_supply) <- sprintf("S_Xs%d", supply_ids)
        sim <- c(sim, extra_demand, extra_supply)
    }

    sim
}

#' Test initialization
#'
#' Simulate data for the benchmarked model, initialize the model from them, and check
#' that it can be estimated. The function is run through `setup_time_benchmark`, which
#' is the untimed `setup` expression of `microbenchmark` and therefore runs before each
#' timed estimation call.
#'
#' New data are simulated only when the global `evaluation` counter is a multiple of
#' three; otherwise the previously initialized global `model` is reused. The counter is
#' incremented only when the function completes without error, so a failed simulation
#' or estimation is retried with fresh data on the next call. Every three consecutive
#' successful calls therefore share the same simulated data. Whether those three calls
#' belong to the three different estimation methods depends on the evaluation order
#' `microbenchmark` uses.
#'
#' As the check, the model is estimated with BFGS using a numerically approximated
#' gradient. `setup_time_benchmark` relies on this so that benchmark data are only
#' gathered for models that can be estimated.
#'
#' @return The initialized model upon success. Errors from simulation,
#'   initialization, or estimation propagate to the caller.
test_initialization <- function() {
    if (evaluation %% 3 == 0) {
        # " + "-terminated list of extra regressor terms, e.g. "Xd3 + Xd4 + ";
        # an empty string when there are none.
        extra_terms <- function(prefix, indices) {
            if (length(indices) == 0) {
                return("")
            }
            paste0(paste0(prefix, indices, collapse = " + "), " + ")
        }
        # Indices of the additional demand (Xd3, Xd4, ...) and supply (Xs2, Xs3, ...)
        # regressors; these match the names generated by prepare_simulation_parameters.
        extra_demand <- if (addpars > 0) 3:(2 + addpars) else integer(0)
        extra_supply <- if (addpars > 0) 2:(1 + addpars) else integer(0)

        parameters <- prepare_simulation_parameters(model_string)
        beta_d <- c(
            parameters$D_Xd1, parameters$D_Xd2,
            vapply(extra_demand, function(i) parameters[[sprintf("D_Xd%d", i)]], numeric(1))
        )
        beta_s <- c(
            parameters$S_Xs1,
            vapply(extra_supply, function(i) parameters[[sprintf("S_Xs%d", i)]], numeric(1))
        )
        simulation_parameters <- list(
            alpha_d = parameters$D_P,
            beta_d0 = parameters$D_CONST,
            beta_d = beta_d,
            eta_d = c(
                parameters$D_X1,
                parameters$D_X2
            ),
            alpha_s = parameters$S_P,
            beta_s0 = parameters$S_CONST,
            beta_s = beta_s,
            eta_s = c(
                parameters$S_X1,
                parameters$S_X2
            ),
            gamma = parameters$P_DIFF,
            beta_p0 = parameters$P_CONST,
            beta_p = c(parameters$P_Xp1),
            sigma_d = parameters$D_VARIANCE,
            sigma_s = parameters$S_VARIANCE,
            sigma_p = parameters$P_VARIANCE,
            rho_ds = parameters$RHO,
            rho_dp = parameters$RHO_DP,
            rho_sp = parameters$RHO_SP
        )

        subject <- "id"
        time <- "date"
        quantity <- "Q"
        price <- "P"
        # Every simulated extra regressor must appear in its equation, otherwise the
        # estimated model silently has fewer parameters than intended.
        demand <- paste0(
            price, " + Xd1 + Xd2 + ", extra_terms("Xd", extra_demand), "X1 + X2"
        )
        supply <- paste0("Xs1 + ", extra_terms("Xs", extra_supply), "X1 + X2")
        if (model_string != "diseq_directional") {
            # Price enters the supply equation of every model except the directional one.
            supply <- paste0(price, " + ", supply)
        }

        specification <- paste0(
            paste(quantity, price, subject, time, sep = " | "), " ~ ",
            paste(demand, supply, sep = " | ")
        )

        if (model_string == "diseq_stochastic_adjustment") {
            # The stochastic adjustment model has an additional price equation part.
            price_dynamics <- "Xp1"

            specification <- paste(specification, price_dynamics, sep = " | ")
        }

        verbose <- 0
        correlated_shocks <- TRUE

        model_tibble <- do.call(markets::simulate_data, c(
            model_type_string = model_string,
            nobs = nobs, tobs = tobs,
            simulation_parameters,
            verbose = verbose
        ))

        model <<- do.call(
            markets:::initialize_from_formula,
            list(
                model_type = model_string,
                specification = formula(specification),
                data = model_tibble,
                correlated_shocks = correlated_shocks, verbose = verbose
            )
        )

        # Check estimation; an error here aborts before the counter is incremented.
        markets::estimate(model,
            control = optimization_controls, method = "BFGS",
            gradient = "numerical", hessian = "skip"
        )
        print_message("Initialization succeeded for evaluation ", evaluation / 3)
    }

    evaluation <<- evaluation + 1

    model
}

#' Setup time benchmark
#'
#' Prepare the model data for the next timed estimation call. `test_initialization`
#' is called repeatedly until it succeeds. It simulates data and checks that the model
#' can be estimated with BFGS using a numerically approximated gradient, considered
#' the most fragile of the three benchmarked methods. Each failure is reported and
#' retried. After more than 10000 failed attempts, the benchmark script stops with
#' an error.
#'
#' @return No return, called for side effects (preparing and validating the data for the benchmark).
setup_time_benchmark <- function() {
    initialized <- FALSE
    error_counter <- 0
    while (!initialized) {
        tryCatch(
            {
                test_initialization()
                initialized <- TRUE
            },
            error = function(e) {
                error_counter <<- error_counter + 1
                if (error_counter > 10000) {
                    stop(make_message("Initialization failed."))
                }
                # as.character() renders the condition as "Error in <call>: <msg>\n".
                # Strip only the trailing whitespace (and a trailing period, if any) so
                # the "...Retrying" suffix follows the message text without losing any
                # of its characters.
                error_text <- sub("[.]?\\s*$", "", as.character(e))
                print_message(sprintf("%s...Retrying", error_text))
            }
        )
    }
}

#' Estimate using BFGS with calculated gradient
#'
#' Timed expression for `microbenchmark`: estimate the global `model` with BFGS,
#' leaving `gradient` at the default of `markets::estimate` and skipping the Hessian.
#'
#' @return An estimated market model.
bfgs_with_calculated_gradient <- function() {
    fit <- markets::estimate(
        model,
        method = "BFGS",
        hessian = "skip",
        control = optimization_controls
    )
    fit
}

#' Estimate using BFGS with numerically approximated gradient
#'
#' Timed expression for `microbenchmark`: estimate the global `model` with BFGS,
#' requesting a numerically approximated gradient and skipping the Hessian.
#'
#' @return An estimated market model.
bfgs_with_numerical_gradient <- function() {
    fit <- markets::estimate(
        model,
        method = "BFGS",
        gradient = "numerical",
        hessian = "skip",
        control = optimization_controls
    )
    fit
}

#' Estimate using Nelder-Mead
#'
#' Timed expression for `microbenchmark`: estimate the global `model` with the
#' Nelder-Mead method, skipping the Hessian.
#'
#' @return An estimated market model.
nm <- function() {
    fit <- markets::estimate(
        model,
        method = "Nelder-Mead",
        hessian = "skip",
        control = optimization_controls
    )
    fit
}

#' Save time statistics
#'
#' Append summary statistics of one benchmark run to the results file given by the
#' global `filename` (the `--out` option). Because benchmarks run in parallel
#' processes, the file is read and rewritten while holding the lock file
#' `time_benchmarks.lck`. The lock is released even if reading or writing fails.
#' Measurements are converted from nanoseconds to seconds. For each estimation method,
#' the function records the minimum, lower quartile, mean, median, upper quartile,
#' maximum, and standard deviation.
#'
#' @param bts Benchmark results returned by `microbenchmark`, with columns `expr`
#'   (the timed expression) and `time` (nanoseconds).
#' @return No return, called for side effects (saving benchmark data to disk).
save_time_statistics <- function(bts) {
    # Summary statistics of a vector of times, named like the results-file columns.
    summarise_times <- function(seconds) {
        c(
            min = min(seconds),
            lq = quantile(seconds, 0.25, names = FALSE),
            mean = mean(seconds),
            median = median(seconds),
            uq = quantile(seconds, 0.75, names = FALSE),
            max = max(seconds),
            sd = sd(seconds)
        )
    }
    # Measurements of one benchmarked expression, converted to seconds.
    seconds_for <- function(expression) {
        bts$time[bts$expr == expression] / 10^9
    }

    # One row per estimation method (row names are the stored `expr` labels), one
    # column per statistic.
    statistics <- t(vapply(
        list(
            bfgs_numerical = seconds_for("bfgs_with_numerical_gradient()"),
            bfgs_calculated = seconds_for("bfgs_with_calculated_gradient()"),
            nm = seconds_for("nm()")
        ),
        summarise_times,
        numeric(7)
    ))
    column <- function(name) unname(statistics[, name])

    # Lock the file because many processes run in parallel.
    lck <- filelock::lock("time_benchmarks.lck")
    on.exit(filelock::unlock(lck), add = TRUE)
    results <- readRDS(filename)
    results <- tibble::add_row(
        results,
        model = rep(model_string, 3),
        nobs = rep(nobs, 3), tobs = rep(tobs, 3),
        # Total number of extra coefficients across both equations.
        addpars = rep(2 * addpars, 3),
        reps = rep(reps, 3),
        expr = rownames(statistics),
        min = column("min"),
        lq = column("lq"),
        mean = column("mean"),
        median = column("median"),
        uq = column("uq"),
        max = column("max"),
        sd = column("sd")
    )
    saveRDS(results, filename)
    print_message("Saved time benchmark record")
}

#' Check if record exists
#'
#' Check whether the results file already holds a time record for the same model,
#' number of entities and time points, additional parameters, and benchmark
#' repetitions. If the results file does not exist yet, it is created as an empty
#' tibble with the expected columns. File access happens while holding the lock file
#' `time_benchmarks.lck`, which is released even if reading or writing fails.
#'
#' @return TRUE if the time record is stored in the data, FALSE otherwise.
time_record_exists <- function() {
    lck <- filelock::lock("time_benchmarks.lck")
    on.exit(filelock::unlock(lck), add = TRUE)

    if (!file.exists(filename)) {
        empty_results <- tibble::tibble(
            model = character(),
            nobs = numeric(),
            tobs = numeric(),
            addpars = numeric(),
            reps = numeric(),
            expr = character(),
            min = numeric(),
            lq = numeric(),
            mean = numeric(),
            median = numeric(),
            uq = numeric(),
            max = numeric(),
            sd = numeric()
        )
        saveRDS(empty_results, filename)
        return(FALSE)
    }

    records <- readRDS(filename)
    matches <- records$model == model_string &
        records$nobs == nobs &
        records$tobs == tobs &
        records$addpars == 2 * addpars &
        records$reps == reps
    # Rows with missing values never count as a match (and must not make `if` fail).
    exists <- any(matches, na.rm = TRUE)
    if (exists) {
        print_message("Skipping benchmark because saved record exists")
    }

    exists
}

#' Gather time measurement data
#'
#' Unless a record for the current configuration already exists, time the estimation
#' of the benchmarked model three ways: BFGS with the calculated gradient, BFGS with a
#' numerically approximated gradient, and Nelder-Mead. Then save summary statistics of
#' the measured times.
#'
#' `setup_time_benchmark` runs, untimed, before every timed call, and
#' `test_initialization` simulates fresh data only on every third successful setup.
#' The expressions are therefore evaluated in order (`order = "inorder"`) rather than
#' microbenchmark's default random order. That way each consecutive triple of calls is
#' one call per method on the same simulated data.
#'
#' @return No return, called for side effects (executing and saving benchmark data).
benchmark_time <- function() {
    if (!time_record_exists()) {
        print_message("Starting time benchmark")
        benchmark_results <- microbenchmark(
            bfgs_with_calculated_gradient(),
            bfgs_with_numerical_gradient(),
            nm(),
            control = list(order = "inorder", warmup = 2),
            unit = "s",
            times = reps,
            setup = setup_time_benchmark()
        )

        save_time_statistics(benchmark_results)
    }
}


# Entry point: run the time benchmark configured by the command-line options.
benchmark_time()
