########## counterfactuals: An R Package for Counterfactual Explanation Methods
########## Replication script
########## ----- Main -----

# Load packages; any CRAN package that cannot be attached is installed first.
pkgs = c("counterfactuals", "iml",
         "rchallenge", "gamlss.data", "randomForest", "mlr3", # Section 4
         "GGally", # Section 4 (parallel plot)
         "R6") # Section 5

rpkgs = vapply(pkgs, require, logical(1L), character.only = TRUE)
if (!all(rpkgs)) {
  install.packages(pkgs[!rpkgs])
}

# featureTweakR and pforeach are installed from GitHub. That needs devtools,
# which is not listed in `pkgs`, so install it on demand before using it.
gh_pkgs = c(featureTweakR = "katokohaku/featureTweakR",
            pforeach = "hoxo-m/pforeach")
for (gh_pkg in names(gh_pkgs)) {
  if (!requireNamespace(gh_pkg, quietly = TRUE)) {
    if (!requireNamespace("devtools", quietly = TRUE)) {
      install.packages("devtools")
    }
    devtools::install_github(gh_pkgs[[gh_pkg]])
  }
}

# Attach everything (CRAN + GitHub packages) and fail loudly, naming the
# packages that could not be attached.
all_pkgs = c(pkgs, names(gh_pkgs))
rpkgs = vapply(all_pkgs, require, logical(1L), character.only = TRUE)

if (!all(rpkgs)) {
  stop("could not attach all required packages: ",
       paste(all_pkgs[!rpkgs], collapse = ", "))
}

############ MOC applied to a classification task ############
## Load data --
# German credit data: keep a subset of the features plus the target
# `credit_risk`.
data("german", package = "rchallenge")
credit = german[, c("duration", "amount", "purpose", "age",
  "employment_duration", "housing", "number_credits", "credit_risk")]

## Train rf model --
# Fixed seed for replicability; all later random draws in this script continue
# from this RNG stream, so statement order matters.
# Row 998 is held out of training and used as the point to explain.
set.seed(20210816)
rf = randomForest(credit_risk ~ ., data = credit[-998L,])

# Set up Predictor
# type = "prob": the predictor returns class probabilities. No `data` is
# passed, so iml presumably recovers the training data from the model call
# (NOTE(review): confirm against iml::Predictor docs).
predictor = iml::Predictor$new(rf, type = "prob")
x_interest = credit[998L, ]
predictor$predict(x_interest)


# MOC (multi-objective counterfactuals) for classification. `age` and
# `employment_duration` are fixed, i.e. not allowed to change. Search is
# stopped by the "genstag" criterion with n_generations = 10L (presumably
# stagnation over that many generations; see ?MOCClassif to confirm).
moc_classif = MOCClassif$new(
  predictor, epsilon = 0, fixed_features = c("age", "employment_duration"),
  termination_crit = "genstag", n_generations = 10L, quiet = TRUE)

# Ask for counterfactuals whose predicted probability for class "good" lies
# in [0.6, 1].
cfactuals = moc_classif$find_counterfactuals(
  x_interest, desired_class = "good", desired_prob = c(0.6, 1))

## Counterfactuals object --
# Summary of the returned Counterfactuals object.
print(cfactuals)
# Predictions for the first three counterfactuals.
head(cfactuals$predict(), 3L)
# Evaluation measures for the first three counterfactuals; show_diff = TRUE
# presumably displays feature differences to x_interest (see ?Counterfactuals).
head(cfactuals$evaluate(show_diff = TRUE, measures = c("dist_x_interest",
  "dist_target", "no_changed", "dist_train")), 3L)

## subset to valid --
# Modifies `cfactuals` in place (no reassignment), so the row count below
# reflects only the counterfactuals that remain after subsetting.
cfactuals$subset_to_valid()
nrow(cfactuals$data)

# Plot relative frequency of changes
# subset_zero = TRUE: presumably omits features that were never changed.
cfactuals$plot_freq_of_feature_changes(subset_zero = TRUE)

# Parallel plot
# GGally is attached here because the parallel plot is drawn with it.
# Only features that appear in the change-frequency table are plotted.
library(GGally)
cfactuals$plot_parallel(feature_names = names(
  cfactuals$get_freq_of_feature_changes()),  digits_min_max = 2L)

# Surface plot
# Prediction surface over the two numeric features duration and amount.
cfactuals$plot_surface(feature_names = c("duration", "amount"))

## MOC diagnostics ---

# Mean and HV plots
# Per-generation objective means and dominated hypervolume of the MOC run.
moc_classif$plot_statistics(centered_obj = TRUE)

# Search plots
# Each call shows the search in the plane of two of the objectives.
# The calls are kept at top level so each plot is auto-printed.
moc_classif$plot_search(objectives = c("dist_train", "dist_target"))
moc_classif$plot_search(objectives = c("dist_x_interest", "dist_train"))


############ NICE applied to a regression task ############
# Load data
data("plasma", package = "gamlss.data")

# Define x_interest: observation 100 is held out of training and explained
x_interest = plasma[100L,]


# Train a regression tree (rpart) with mlr3 on all rows except x_interest.
# Namespaced calls keep this consistent with the rest of the script.
tsk = mlr3::TaskRegr$new(id = "plasma", backend = plasma[-100L,],
  target = "retplasma")
tree = mlr3::lrn("regr.rpart")
model = tree$train(tsk)
# The predictor is given the full data (including row 100); retplasma is the
# target column.
predictor = iml::Predictor$new(model, data = plasma, y = "retplasma")
predictor$predict(x_interest)

# Call NICE method
# NICE for regression: look for a counterfactual whose predicted retplasma
# lies in [500, Inf).
# optimization = "proximity"; return_multiple = FALSE presumably returns only
# the final counterfactual; margin_correct = 0.5 (semantics: see ?NICERegr).
nice_reg = NICERegr$new(predictor, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE)
cfactuals = nice_reg$find_counterfactuals(x_interest,
  desired_outcome = c(500, Inf))

# Surface plot
# Prediction surface over betaplasma and age, evaluated on a finer grid
# (grid_size = 200).
cfactuals$plot_surface(feature_names = c("betaplasma", "age"), grid_size = 200)

# User-defined distance function
# l0_norm: counts, for every pair of rows, the number of columns in which the
# two rows differ (an L0 "distance").
#   x, y : data frames, assumed to share the same column layout (columns are
#          compared by position).
#   data : required by the distance_function signature; unused here.
# Returns a nrow(x) x nrow(y) matrix; entry [i, j] is the number of differing
# columns between x[i, ] and y[j, ]. An NA in a compared column gives NA.
l0_norm = function(x, y, data) {
  n_x = nrow(x)
  n_y = nrow(y)
  dist_mat = matrix(NA, nrow = n_x, ncol = n_y)
  for (row_x in seq_len(n_x)) {
    # Extract the row of x once rather than once per row of y
    obs = x[row_x, ]
    for (row_y in seq_len(n_y)) {
      n_diff = sum(obs != y[row_y, ])
      dist_mat[row_x, row_y] = n_diff
    }
  }
  dist_mat
}
# Toy check of l0_norm: one row vs three rows gives a 1 x 3 matrix of
# differing-column counts.
xt = data.frame(a = c(0.5), b = c("a"))
yt = data.frame(a = c(0.5, 3.2, 0.1), b = c("a", "b", "a"))
l0_norm(xt, yt, data = NULL)

# Same NICE setup as above, but distances are computed with the custom
# l0_norm instead of the default distance. The result is auto-printed.
nice_reg = NICERegr$new(predictor, optimization = "proximity",
  margin_correct = 0.5, return_multiple = FALSE,
  distance_function = l0_norm)
nice_reg$find_counterfactuals(x_interest, desired_outcome = c(500, Inf))


############## Extension of the package ##############
# Create R6 Class
# FeatureTweakerClassif: wraps featureTweakR's feature tweaking as a
# counterfactuals classification method (subclass of
# CounterfactualMethodClassif). The method works on the tree rules of a
# randomForest model, so only randomForest models are supported.
FeatureTweakerClassif = R6Class("FeatureTweakerClassif",
  inherit = CounterfactualMethodClassif,

  public = list(
    # predictor: iml Predictor wrapping a randomForest classifier.
    # ktree, resample: forwarded to featureTweakR::getRules() (ktree = NULL
    #   presumably means "use all trees"; see featureTweakR docs).
    # epsiron: forwarded to featureTweakR::set.eSatisfactory() (spelling
    #   follows featureTweakR's argument name).
    initialize = function(predictor, ktree = NULL, epsiron = 0.1,
      resample = FALSE) {
      super$initialize(predictor)
      private$ktree = ktree
      private$epsiron = epsiron
      private$resample = resample
    }
  ),

  private = list(
    ktree = NULL,
    epsiron = NULL,
    resample = NULL,

    # Runs feature tweaking for private$x_interest and returns the suggested
    # counterfactual(s) from featureTweakR::tweak().
    run = function() {
      predictor = private$predictor
      rf = predictor$model
      # featureTweakR extracts rules from randomForest trees; fail early with
      # a clear message rather than deep inside featureTweakR.
      if (!inherits(rf, "randomForest")) {
        stop("FeatureTweakerClassif requires a randomForest model.",
          call. = FALSE)
      }

      # Current class of x_interest = class with the highest probability.
      # unlist() turns the one-row prediction data.frame into a named vector.
      y_hat_interest = unlist(predictor$predict(private$x_interest))
      class_x_interest = names(y_hat_interest)[which.max(y_hat_interest)]

      rules = featureTweakR::getRules(rf, ktree = private$ktree,
        resample = private$resample)
      es = featureTweakR::set.eSatisfactory(rules, epsiron = private$epsiron)
      tweaks = featureTweakR::tweak(
        es, rf, private$x_interest, label.from = class_x_interest,
        label.to = private$desired_class, .dopar = FALSE
      )
      tweaks$suggest
    },

    # Prints the method's hyperparameters, one per line.
    print_parameters = function() {
      cat(" - epsiron: ", private$epsiron, "\n")
      cat(" - ktree: ", private$ktree, "\n")
      cat(" - resample: ", private$resample, "\n")
    }
  )
)

# Iris data example
# Seeded for replicability. Row 150 is held out as x_interest; a 20-tree
# random forest is trained on the remaining 149 rows.
set.seed(78546)
X = subset(iris, select = -Species)[-150L,]
y = iris$Species[-150L]
rf = randomForest(X, y, ntree = 20L)
# The predictor gets the full training rows including Species; y = "Species"
# marks it as the target column.
predictor = iml::Predictor$new(rf, data = iris[-150L, ], y = "Species", 
  type = "prob")

x_interest = iris[150L, ]
predictor$predict(x_interest)

# Rules are taken from ktree = 20L trees with resample = TRUE. The desired
# counterfactual has predicted probability for "versicolor" in [0.6, 1].
ft_classif = FeatureTweakerClassif$new(predictor, ktree = 20L, resample = TRUE)
cfactuals = ft_classif$find_counterfactuals(x_interest = x_interest, 
  desired_class = "versicolor", desired_prob = c(0.6, 1))

