########## counterfactuals: An R Package for Counterfactual Explanation Methods
########## Replication script 
########## ----- Appendix -----

# Load packages: attach every package the script needs, installing any that
# are missing first.
# "gamlss.data" is only used for data("plasma", package = "gamlss.data") in
# Appendix B.3, but data() still requires it to be installed.
pkgs = c("counterfactuals", "iml",
         "caret", "tidymodels", "mlr", "rpart", "gamlss.data") # Appendix B.3.

# require() returns TRUE/FALSE per package; vapply() guarantees a named
# logical vector even in edge cases where sapply() would change type.
rpkgs = vapply(pkgs, require, logical(1L), character.only = TRUE)
if (!all(rpkgs)) {
  # install.packages() accepts a vector of names, so install all missing
  # packages in one call, then try attaching everything again.
  install.packages(pkgs[!rpkgs])
  rpkgs = vapply(pkgs, require, logical(1L), character.only = TRUE)
}

if (!all(rpkgs))
  stop("could not attach all required packages")

############## Appendix B.3. Different machine learning interfaces ##############
# Re-check everything this section uses, so Appendix B.3 can also be run on
# its own:
# - Predictor comes from iml, NICERegr from counterfactuals.
# - The model interfaces come from caret, tidymodels, mlr and rpart.
# - The `plasma` data set comes from gamlss.data.
# Packages that are already attached are left untouched by require().
pkgs = c("counterfactuals", "iml",
         "caret", "tidymodels", "mlr", "rpart", "gamlss.data")
rpkgs = vapply(pkgs, require, logical(1L), character.only = TRUE)
if (!all(rpkgs)) {
  # install.packages() is vectorised: install all missing packages at once
  install.packages(pkgs[!rpkgs])
  rpkgs = vapply(pkgs, require, logical(1L), character.only = TRUE)
}
if (!all(rpkgs))
  stop("could not attach all required packages")

# Load the `plasma` data set from gamlss.data into the workspace.
# Row 100 is the observation to explain; the model examples below train on
# all other rows (plasma[-100L, ]).
data(list = "plasma", package = "gamlss.data")
x_interest = plasma[100L, , drop = FALSE]

## caret
# Only one cp value is supplied, so there is nothing to tune. Disable caret's
# resampling (trainControl(method = "none")) rather than refitting the tree
# on resamples whose results are never used. caret still fits the final
# model on the full training data, so the returned tree is unchanged.
library("caret")
treecaret = caret::train(retplasma ~ ., data = plasma[-100L, ],
                         method = "rpart", tuneGrid = data.frame(cp = 0.01),
                         trControl = caret::trainControl(method = "none"))
predcaret = Predictor$new(model = treecaret, data = plasma[-100L, ],
                          y = "retplasma")
predcaret$predict(x_interest)

# NICE counterfactuals for a predicted retplasma of at least 500
nicecaret = NICERegr$new(predcaret, optimization = "proximity",
                         margin_correct = 0.5, return_multiple = FALSE)
nicecaret$find_counterfactuals(x_interest = x_interest,
                               desired_outcome = c(500, Inf))

# tidymodels: parsnip decision tree (rpart engine), fitted on every row except
# the point of interest, then explained with NICE.
library("tidymodels")
treetm = decision_tree(engine = "rpart", mode = "regression") %>%
  fit(formula = retplasma ~ ., data = plasma[-100L, ])
predtm = Predictor$new(
  model = treetm,
  data = plasma[-100L, ],
  y = "retplasma"
)
predtm$predict(x_interest)
nicetm = NICERegr$new(
  predtm,
  optimization = "proximity",
  margin_correct = 0.5,
  return_multiple = FALSE
)
nicetm$find_counterfactuals(
  x_interest = x_interest,
  desired_outcome = c(500, Inf)
)


# mlr: regression task plus "regr.rpart" learner, trained on every row except
# the point of interest, then explained with NICE.
# NOTE: the `data` expression is kept inline because makeRegrTask() derives
# its default task id from it.
library("mlr")
task = mlr::makeRegrTask(target = "retplasma", data = plasma[-100L, ])
mod = mlr::makeLearner(cl = "regr.rpart")
treemlr = mlr::train(learner = mod, task = task)
predmlr = Predictor$new(
  model = treemlr,
  data = plasma[-100L, ],
  y = "retplasma"
)
predmlr$predict(x_interest)
nicemlr = NICERegr$new(
  predmlr,
  optimization = "proximity",
  margin_correct = 0.5,
  return_multiple = FALSE
)
nicemlr$find_counterfactuals(
  x_interest = x_interest,
  desired_outcome = c(500, Inf)
)

# rpart: plain rpart regression tree with default settings, trained on every
# row except the point of interest, then explained with NICE.
library("rpart")
treerpart = rpart(formula = retplasma ~ ., data = plasma[-100L, ])
predrpart = Predictor$new(
  model = treerpart,
  data = plasma[-100L, ],
  y = "retplasma"
)
predrpart$predict(x_interest)
nicerpart = NICERegr$new(
  predrpart,
  optimization = "proximity",
  margin_correct = 0.5,
  return_multiple = FALSE
)
nicerpart$find_counterfactuals(
  x_interest = x_interest,
  desired_outcome = c(500, Inf)
)
