# ------------------------------
# Preliminaries, load packages
# ------------------------------
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# from abcpy.backends import BackendMPI as Backend
# the above is in case you want to use MPI, with `mpirun -n <number tasks> python code.py`
from abcpy.backends import BackendDummy as Backend
from abcpy.output import Journal
from abcpy.perturbationkernel import DefaultKernel
from abcpy.probabilisticmodels import ProbabilisticModel, Continuous, InputConnector
from abcpy.statistics import Identity, Statistics
from abcpy.statisticslearning import SemiautomaticNN

# Make sure the directory used for the journal files ("Results/*.jrnl") exists before
# any inference below tries to load from or save to it.
os.makedirs("Results", exist_ok = True)

## To fully reproduce the interim results saved in Results/*, please delete/rename the
## folder.

# ------------------------------
# Stochastic Lorenz Model
# ------------------------------

class StochLorenz95(ProbabilisticModel, Continuous):
    """Generates time dependent 'slow' weather variables following forecast model of Wilks [1],
        a stochastic reparametrization of original Lorenz model Lorenz [2].

        [1] Wilks, D. S. (2005). Effects of stochastic parametrizations in the lorenz 96 system.
        Quarterly Journal of the Royal Meteorological Society, 131(606), 389–407.

        [2] Lorenz, E. (1995). Predictability: a problem partly solved. In Proceedings of the
        Seminar on Predictability, volume 1, pages 1–18. European Center on Medium Range
        Weather Forecasting, Europe

        Parameters
        ----------
        parameters: list
            List of length 5, ``[theta1, theta2, sigma_e, phi, n_timestep]``, containing the probabilistic
            models and hyperparameters from which the model derives. ``theta1`` and ``theta2`` are the
            coefficients of the deterministic closure, ``sigma_e`` and ``phi`` the scale and autoregressive
            coefficient of the stochastic forcing, and ``n_timestep`` the number of points of the time grid
            on [0, 4] (the original documentation states that 4 corresponds to 20 days).
        F: float, optional
            Constant forcing term of the slow-variable equations. The default value is 10.
        b, h, c, J: numbers, optional
            Constants of the true two-tier model (used by `Lorenz95True` only). J is the number of fast
            variables per slow variable. c = 10 is also used sometimes; it corresponds to the easy case,
            while c = 4 is harder.
        name: str, optional
            Name of the model. The default value is "StochLorenz95".

        Notes
        -----
        The initial state of the 40 slow variables is fixed (`self.initial_state`); the original
        documentation describes it as a previously computed value from a full Lorenz model.
        """

    def __init__(self, parameters, F = 10, b = 10, h = 1, c = 4, J = 8, name = "StochLorenz95"):
        # Fixed initial state of the slow variables
        self.initial_state = np.array([6.4558, 1.1054, -1.4502, -0.1985, 1.1905, 2.3887, 5.6689, 6.7284, 0.9301,
                                       4.4170, 4.0959, 2.6830, 4.7102, 2.5614, -2.9621, 2.1459, 3.5761, 8.1188,
                                       3.7343, 3.2147, 6.3542, 4.5297, -0.4911, 2.0779, 5.4642, 1.7152, -1.2533,
                                       4.6262, 8.5042, 0.7487, -1.3709, -0.0520, 1.3196, 10.0623, -2.4885, -2.1007,
                                       3.0754, 3.4831, 3.5744, 6.5790])
        self.F = F

        # define the parameters of the true model:
        self.b = b
        self.h = h
        self.c = c  # 10 is also used sometimes; it corresponds to the easy case, while c = 4 is harder.
        self.J = J
        # Constants that appear repeatedly in the true two-tier equations
        self.hc_over_b = self.h * self.c / self.b
        self.cb = self.c * self.b

        # We expect input of type parameters = [theta1, theta2, sigma_e, phi, n_timestep]
        if not isinstance(parameters, list):
            raise TypeError("Input of StochLorenz95 model is of type list")

        if len(parameters) != 5:
            raise RuntimeError("Input list must be of length 5, containing [theta1, theta2, sigma_e, phi, n_timestep].")

        input_connector = InputConnector.from_list(parameters)
        super().__init__(input_connector, name)

    def _check_input(self, input_values):
        """Return True if `input_values` = [theta1, theta2, sigma_e, phi, n_timestep] is admissible.

        Raises ValueError if the number of values is not 5; returns False if n_timestep is not positive,
        sigma_e is negative or phi lies outside [0, 1]. theta1 and theta2 are deliberately unconstrained,
        as the parameters of the deterministic part may be negative.
        """
        if len(input_values) != 5:
            raise ValueError("Number of parameters of StochLorenz95 model must be 5.")

        sigma_e = input_values[2]
        phi = input_values[3]
        n_timestep = input_values[4]

        if n_timestep <= 0 or sigma_e < 0 or phi < 0 or phi > 1:
            return False

        return True

    def _check_output(self, values):
        """Return True if the first element of `values` is a numpy array, raise ValueError otherwise."""
        if not isinstance(values[0], np.ndarray):
            raise ValueError("Output of the StochLorenz95 model must be a numpy array.")
        return True

    def get_output_dimension(self):
        return 1

    def forward_simulate(self, input_values, k, rng = np.random.RandomState()):
        """Simulate k time-series of the stochastic (parametrized) model.

        Parameters
        ----------
        input_values: list
            [theta1, theta2, sigma_e, phi, n_timestep]
        k: int
            Number of independent simulations.
        rng: numpy.random.RandomState, optional
            Random number generator driving the stochastic forcing. It is forwarded to the simulator so
            that passing a seeded generator gives reproducible output.

        Returns
        -------
        list
            k numpy arrays, each wrapping one flattened time-series (as np.array([timeseries])).
        """
        # Extract the input parameters
        theta1 = input_values[0]
        theta2 = input_values[1]
        sigma_e = input_values[2]
        phi = input_values[3]
        n_timestep = input_values[4]

        # Do the actual forward simulation, using the generator supplied by the caller
        vector_of_k_samples = self.Lorenz95(theta1, theta2, sigma_e, phi, n_timestep, k, rng = rng)
        # Format the output to obey API
        result = [np.array([x]) for x in vector_of_k_samples]
        return result

    def forward_simulate_true_model(self, n_timestep, k, rng = np.random.RandomState()):
        """Simulate k realisations of the true two-tier model; returns (slow, fast) lists formatted as in
        `forward_simulate`."""
        # Do the actual forward simulation
        vector_of_k_samples_x, vector_of_k_samples_y = self.Lorenz95True(n_timestep, k, rng = rng)
        # Format the output to obey API
        result_x = [np.array([x]) for x in vector_of_k_samples_x]
        result_y = [np.array([y]) for y in vector_of_k_samples_y]
        return result_x, result_y

    def Lorenz95(self, theta1, theta2, sigma_e, phi, n_timestep, k, rng = None):
        """Generate k time-series of the slow variables under the stochastic parametrization.

        Parameters
        ----------
        theta1, theta2: float
            Coefficients of the deterministic closure term.
        sigma_e, phi: float
            Scale and autoregressive coefficient of the stochastic forcing.
        n_timestep: int
            Number of points of the time grid on [0, 4].
        k: int
            Number of time-series to generate.
        rng: numpy.random.RandomState, optional
            Random number generator; a fresh unseeded one is created when None.

        Returns
        -------
        list
            k flattened arrays of length (number of slow variables) * n_timestep.
        """
        if rng is None:
            rng = np.random.RandomState()

        result = []

        # Initialize timesteps.
        time_steps = np.linspace(0, 4, n_timestep)

        n_slow = self.initial_state.shape[0]
        theta = np.array([theta1, theta2])
        noise_scale = sigma_e * np.sqrt(1 - pow(phi, 2))

        for _ in range(k):
            # Stochastic forcing term, one independent draw per slow variable at the start
            eta = noise_scale * rng.normal(0, 1, n_slow)

            # Initialize the time-series
            timeseries = np.zeros(shape = (n_slow, n_timestep), dtype = float)
            timeseries[:, 0] = self.initial_state
            # Compute the timeseries for each time steps
            for ind in range(0, n_timestep - 1):
                # Each timestep is computed by using a 4th order Runge-Kutta solver; the forcing is
                # held fixed within the step
                x = self._rk4ode(self._l95ode_par, time_steps[ind:ind + 2], timeseries[:, ind], [eta, theta])
                timeseries[:, ind + 1] = x[:, -1]
                # Update stochastic forcing term (AR(1) recursion).
                # NOTE(review): a single scalar innovation is shared by all components here, whereas the
                # initial forcing uses independent draws per component -- confirm which one is intended.
                eta = phi * eta + noise_scale * rng.normal(0, 1)
            result.append(timeseries.flatten())
        return result

    def Lorenz95True(self, n_timestep, k, rng = np.random.RandomState()):
        """Generate k time-series of the slow (X) and fast (Y) variables of the true two-tier model.

        Note that here there is randomness in the choice of the starting value of the y variables. They are
        drawn uniformly in [0,1] once per call, before the loop over the k series.

        Returns
        -------
        tuple of lists
            (result_X, result_Y), each holding k flattened time-series.
        """
        # TODO: should implement a better ode solver for the true model!

        result_X = []
        result_Y = []

        # Initialize timesteps
        # it is better to use a smaller timestep for the true model, as the solver may diverge otherwise.
        time_steps = np.linspace(0, 4, n_timestep)

        n_slow = self.initial_state.shape[0]

        # define the initial state of the Y variables. We take self.J fast variables per slow variable
        self.initial_state_Y = rng.uniform(size = (n_slow * self.J))
        n_fast = self.initial_state_Y.shape[0]

        for _ in range(k):
            # Initialize the time-series
            timeseries_X = np.zeros(shape = (n_slow, n_timestep), dtype = float)
            timeseries_X[:, 0] = self.initial_state

            timeseries_Y = np.zeros(shape = (n_fast, n_timestep), dtype = float)
            timeseries_Y[:, 0] = self.initial_state_Y

            # Compute the timeseries for each time steps
            # the loop would not be needed if we wrote a single ode function for both set of variables.
            for ind in range(0, n_timestep - 1):
                step_span = time_steps[ind:ind + 2]
                # Each set of variables is advanced with a 4th order Runge-Kutta step in which the value
                # of the other set (at the start of the step) is passed as the parameter.
                x = self._rk4ode(self._l95ode_true_X, step_span, timeseries_X[:, ind], [timeseries_Y[:, ind]])
                y = self._rk4ode(self._l95ode_true_Y, step_span, timeseries_Y[:, ind], [timeseries_X[:, ind]])

                timeseries_X[:, ind + 1] = x[:, -1]
                timeseries_Y[:, ind + 1] = y[:, -1]

            result_X.append(timeseries_X.flatten())
            result_Y.append(timeseries_Y.flatten())

        return result_X, result_Y

    def _slow_tendency(self, x):
        """Advection, damping and forcing terms shared by the slow-variable equations.

        Evaluates -x[i-2] * x[i-1] + x[i-1] * x[i+1] - x[i] + F for every i, with indices taken
        periodically (the variables form a ring).
        """
        x_prev = np.roll(x, 1)  # x[i - 1]
        x_prev2 = np.roll(x, 2)  # x[i - 2]
        x_next = np.roll(x, -1)  # x[i + 1]
        return -x_prev2 * x_prev + x_prev * x_next - x + self.F

    def _l95ode_par(self, t, x, parameter):
        """
        The parameterized two-tier lorenz 95 system defined by a set of symmetric
        ordinary differential equation. This function evaluates the differential
        equations at a value x of the time-series

        Parameters
        ----------
        t: float
            Time at which the ODE is evaluated (unused, the system is autonomous).
        x: numpy.ndarray of dimension px1
            The value of timeseries where we evaluate the ODE
        parameter: Python list
            [eta, theta]: the stochastic forcing and the coefficients of the polynomial closure.
        Returns
        -------
        numpy.ndarray
            Evaluated value of the ode at a fixed timepoint
        """
        eta = parameter[0]
        theta = parameter[1]
        # Deterministic parameterization for fast weather variables
        # ---------------------------------------------------------
        # assumed to be polynomial in x, with as many coefficients (constant term first) as theta has entries
        degree = theta.shape[0]
        X = np.ones(shape = (x.shape[0], 1))
        for ind in range(1, degree):
            X = np.column_stack((X, pow(x, ind)))

        # deterministic reparametrization term
        # ------------------------------------
        gu = np.sum(X * theta, 1)

        # ODE definition of the slow variables
        # ------------------------------------
        # Evaluated for all variables at once with periodic neighbours, so that every component
        # (including the last two) gets its tendency.
        return self._slow_tendency(x) - gu + eta

    def _l95ode_true_X(self, t, x, parameter):
        """
        The equations for the x variables in the true two-tier lorenz 95 system defined by
        a set of symmetric ordinary differential equation. This function evaluates the
        differential equations at a value x of the time-series, given the corresponding value of
        the y variables.

        Parameters
        ----------
        t: float
            Time at which the ODE is evaluated (unused, the system is autonomous).
        x: numpy.ndarray of dimension px1
            The value of timeseries where we evaluate the ODE
        parameter: Python list
            It is a list with a single element that is the value of the y variables
            (self.J consecutive fast variables per slow variable).
        Returns
        -------
        numpy.ndarray
            Evaluated value of the ode at a fixed timepoint
        """
        y = parameter[0]

        # Sum of the self.J fast variables attached to each slow variable
        fast_sums = y.reshape(x.shape[0], self.J).sum(axis = 1)

        # ODE definition of the slow variables
        # ------------------------------------
        return self._slow_tendency(x) - self.hc_over_b * fast_sums

    def _l95ode_true_Y(self, t, y, parameter):
        """
        The equations for the y variables in the true two-tier lorenz 95 system defined by
        a set of symmetric ordinary differential equation. This function evaluates the
        differential equations at a value y of the time-series, given the corresponding value of
        the x variables.

        Parameters
        ----------
        t: float
            Time at which the ODE is evaluated (unused, the system is autonomous).
        y: numpy.ndarray of dimension px1
            The value of timeseries where we evaluate the ODE
        parameter: Python list
            It is a list with a single element that is the value of the x variables.
        Returns
        -------
        numpy.ndarray
            Evaluated value of the ode at a fixed timepoint
        """
        x = parameter[0]

        y_next = np.roll(y, -1)  # y[j + 1]
        y_next2 = np.roll(y, -2)  # y[j + 2]
        y_prev = np.roll(y, 1)  # y[j - 1]
        # Slow variable driving each fast variable: x[j // J]
        x_driver = np.repeat(x, self.J)

        # ODE definition of the fast variables
        # ------------------------------------
        # Evaluated for all variables at once with periodic neighbours, so that every component
        # (including the last three) gets its tendency.
        return - self.cb * y_next * (y_next2 - y_prev) - self.c * y + self.hc_over_b * x_driver

    def _rk4ode(self, ode, timespan, timeseries_initial, parameter):
        """
        4th order runge-kutta ODE solver.

        Parameters
        ----------
        ode: function
            The function defining Ordinary differential equation
        timespan: numpy.ndarray
            A numpy array containing the timepoints where the ode needs to be solved.
            The first time point corresponds to the initial value
        timeseries_initial: np.ndarray of dimension px1
            Intial value of the time-series, corresponds to the first value of timespan
        parameter: Python list
            The parameters needed to evaluate the ode
        Returns
        -------
        np.ndarray
            Timeseries initiated at timeseries_init and satisfying ode solved by this solver.
        """

        timeseries = np.zeros(shape = (timeseries_initial.shape[0], timespan.shape[0]))
        timeseries[:, 0] = timeseries_initial

        state = timeseries_initial
        for ind in range(0, timespan.shape[0] - 1):
            time_diff = timespan[ind + 1] - timespan[ind]
            time_mid_point = timespan[ind] + time_diff / 2
            k1 = time_diff * ode(timespan[ind], state, parameter)
            k2 = time_diff * ode(time_mid_point, state + k1 / 2, parameter)
            k3 = time_diff * ode(time_mid_point, state + k2 / 2, parameter)
            k4 = time_diff * ode(timespan[ind + 1], state + k3, parameter)
            state = state + (k1 + 2 * k2 + 2 * k3 + k4) / 6
            timeseries[:, ind + 1] = state
        # Return the solved timeseries at the values in timespan
        return timeseries


# ------------------------------
# Different Neural networks used
# ------------------------------

class PhiNetwork(nn.Module):
    """Fully connected network with layer widths 80 -> 160 -> 80 -> 40.

    ReLU is applied after the first two layers; the last layer is linear.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features = 80, out_features = 160)
        self.fc2 = nn.Linear(in_features = 160, out_features = 80)
        self.fc3 = nn.Linear(in_features = 80, out_features = 40)

    def forward(self, x):
        """Apply the network to the last dimension of x (which must have size 80)."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)


class RhoNetwork(nn.Module):
    """Fully connected network with layer widths 80 -> 120 -> 120 -> 50 -> n_parameters.

    ReLU is applied after the first three layers; the last layer is linear.
    """

    def __init__(self, n_parameters):
        super().__init__()
        self.fc1 = nn.Linear(in_features = 80, out_features = 120)
        self.fc2 = nn.Linear(in_features = 120, out_features = 120)
        self.fc3 = nn.Linear(in_features = 120, out_features = 50)
        self.fc4 = nn.Linear(in_features = 50, out_features = n_parameters)

    def forward(self, x):
        """Apply the network to the last dimension of x (which must have size 80)."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.fc4(hidden)


class PEN1(nn.Module):
    """Implementation of the Partially Exchangeable Network from [1].

    [1] Wiqvist, S., Mattei, P.A., Picchini, U. and Frellsen, J., 2019. Partially Exchangeable Networks and
    Architectures for Learning Summary Statistics in Approximate Bayesian Computation. arXiv preprint arXiv:1901.10230.
    """

    def __init__(self, phi_net, rho_net, n_timestep = 120):
        """Store the inner network `phi_net`, the outer network `rho_net` and the number of timesteps.

        For the 40-variable Lorenz output the shapes used in `forward` imply that:
        - phi_net receives 80 inputs (the 40 variables at two consecutive timesteps);
        - rho_net receives 40 + (output size of phi_net) inputs;
        - the output of rho_net is the output of this network (one value per parameter to be estimated).
        """
        super().__init__()
        self.phi_net = phi_net
        self.rho_net = rho_net
        self.n_timestep = n_timestep

    def forward(self, x):
        """x should be a tensor of shape (n_samples, n_variables * n_timestep), that is the flattened output of
            the Lorenz model (n_variables = 40 there); it is reshaped back internally.

            Raises RuntimeError if x does not have exactly 2 dimensions.
        """
        if x.dim() != 2:
            raise RuntimeError("The input must have 2 dimensions.")

        n_steps = self.n_timestep
        # (n_samples, n_variables, n_timestep)
        series = x.reshape(x.shape[0], -1, n_steps)

        # Pair every timestep with its successor: stack the two along the variable axis and move
        # time to axis 1 -> (n_samples, n_timestep - 1, 2 * n_variables)
        current_steps = series[:, :, 0:n_steps - 1]
        following_steps = series[:, :, 1:n_steps]
        pairs = torch.cat((current_steps, following_steps), 1).transpose(2, 1)

        # Inner network on every pair, then sum over the time axis
        pooled = torch.sum(self.phi_net(pairs), dim = 1)

        # Outer network on the pooled features concatenated with the first timestep
        first_step = series[:, :, 0]
        return self.rho_net(torch.cat((pooled, first_step), dim = 1))


# ------------------------------
# Hakkarainen summary statistics
# ------------------------------

class HakkarainenLorenzStatistics(Statistics):
    """
    This class implements the statistics function from the Statistics protocol. This
    extracts the statistics following Hakkarainen et. al. [1] from the multivariate timeseries
    generated by solving Lorenz 95 odes.

    [1] J. Hakkarainen, A. Ilin, A. Solonen, M. Laine, H. Haario, J. Tamminen, E. Oja, and
    H. Järvinen. On closure parameter estimation in chaotic systems. Nonlinear Processes
    in Geophysics, 19(1):127–143, Feb. 2012.
    """

    def __init__(self, degree = 2, cross = True):
        # Settings stored for the polynomial expansion applied at the end of `statistics`
        self.cross = cross
        self.degree = degree

    def statistics(self, data):
        """Compute the six summary statistics of every data set, followed by the polynomial expansion.

        Parameters
        ----------
        data: list
            One entry per data set; each entry is flattened and then reshaped to (40, timestep).

        Returns
        -------
        numpy.ndarray
            Result of `self._polynomial_expansion` applied to the (len(data), 6) matrix of summaries.

        Raises
        ------
        TypeError
            If data is not a list.
        """
        if not isinstance(data, list):
            raise TypeError("Input data should be of type list, but found type {}".format(type(data)))

        n_datasets = len(data)
        if np.array(data).shape == (n_datasets,):
            # list of scalars: one column per data set
            data = np.array(data).reshape(n_datasets, 1)
        else:
            data = np.concatenate(data).reshape(n_datasets, -1)

        ## Extract Hakkarainen Summary Statistics
        num_element = len(data)
        timestep = int(data[0].shape[0] / 40)
        result = np.zeros(shape = (num_element, 6))
        for row, flat_series in enumerate(data):
            # First convert the vector to the 40 dimensional timeseries
            result[row, :] = self._six_summaries(flat_series.reshape(40, timestep))

        # Expand the data with polynomial expansion
        result = self._polynomial_expansion(result)
        return np.array(result)

    def _six_summaries(self, series):
        """Return [s1, ..., s6] for one (n_variables, timestep) time-series.

        s1: mean, s2: variance (both averaged over the variables), s3: lag-1 auto-covariance,
        s4: covariance of neighbouring variables, s5/s6: cross-covariance (see `_cross_covariance`)
        with the previous/next variable. All sums are divided by the number of variables.
        """
        n_vars = series.shape[0]

        # Mean
        s1 = np.mean(np.mean(series, 1))
        # Variance
        s2 = np.mean(np.var(series, 1))

        ## Auto Covariance with lag 1
        s3 = 0.0
        for variable in series:
            s3 += self._auto_covariance(variable, lag = 1)
        s3 = s3 / n_vars

        ## Covariance with a neighboring node
        # NOTE(review): only the n_vars - 1 pairs (i, i + 1) are summed (no wrap-around pair), while the
        # sum is divided by n_vars -- confirm this is intended.
        s4 = 0.0
        for left, right in zip(series[:-1], series[1:]):
            s4 += np.mean(left * right) - np.mean(left) * np.mean(right)
        s4 = s4 / n_vars

        ## Cross-Cov with 2 neighbors with time lag 1
        # Variable i paired with variable i - 1, for i = 1, ..., n_vars - 2
        s5 = 0.0
        for variable, previous in zip(series[1:-1], series[:-2]):
            s5 += self._cross_covariance(variable, previous)
        s5 = s5 / n_vars

        # Variable i paired with variable i + 1, for i = 1, ..., n_vars - 2, plus one extra term at each end.
        # NOTE(review): the two extra terms (1, 2) and (-2, -1) repeat pairs already covered by the loop --
        # confirm this is intended.
        s6 = self._cross_covariance(series[1, :], series[2, :])
        for variable, following in zip(series[1:-1], series[2:]):
            s6 += self._cross_covariance(variable, following)
        s6 += self._cross_covariance(series[-2, :], series[-1, :])
        s6 = s6 / n_vars

        return [s1, s2, s3, s4, s5, s6]

    def _cross_covariance(self, x, y):
        """ Computes cross-covariance between x and y

        Both vectors are extended by a single entry equal to 1 before the covariance is taken: x gets it at
        the front, y gets it at index -1 (np.insert places it before the last element).
        NOTE(review): if a pure one-step shift was intended, the 1 would have to be appended after the last
        element of y -- confirm.

        Parameters
        ----------
        x: numpy.ndarray
            Vector of real numbers.
        y: numpy.ndarray
            Vector of real numbers.

        Returns
        -------
        numpy.ndarray
            Cross-covariance calculated between x and y.
        """
        padded_x = np.insert(x, 0, 1)
        padded_y = np.insert(y, -1, 1)
        return np.mean(padded_x * padded_y) - np.mean(padded_x) * np.mean(padded_y)

    def _auto_covariance(self, x, lag = 1):
        """
        Calculate the autocovariance coefficient of x with the given lag.

        Parameters
        ----------
        x: numpy.ndarray
            Vector of real numbers.
        lag: integer
            Time-lag.

        Returns
        -------
        numpy.ndarray
            Returns the auto-covariance of x with time-lag `lag`, normalised by len(x) - 1.
        """
        n_obs = x.shape[0]
        centre = np.average(x)

        accumulated = 0.0
        for start in range(0, n_obs - lag):
            accumulated += (x[start + lag] - centre) * (x[start] - centre)
        return (1 / (n_obs - 1)) * accumulated


# ----------------------------
#
# Section 3
#
# ----------------------------

# ------------------------------------------------------------------------------------------
# Here we perform different inference tasks for the Lorenz model and create the corresponding figures
# ------------------------------------------------------------------------------------------

# Section 3
## Define Graphical Model
from abcpy.continuousmodels import Uniform
# Priors on the two closure parameters that are inferred
theta1 = Uniform([[0.5], [3.5]], name = "theta1")
theta2 = Uniform([[0], [0.3]], name = "theta2")
# Noise parameters and number of timesteps are kept fixed
sigma_e = 1; phi = 0.4; T = 1024
lorenz = StochLorenz95([theta1, theta2, sigma_e, phi, T], name = "lorenz")


## define the Hakkarainen statistic
statistics_calculator = HakkarainenLorenzStatistics(degree = 1, cross = False)

# Define distance
from abcpy.distances import Euclidean
distance_calculator = Euclidean(statistics_calculator)

# Define perturbation kernel
from abcpy.perturbationkernel import DefaultKernel
kernel = DefaultKernel([theta1, theta2])

## Define backend
backend = Backend()


### Define sampler (here we illustrate 4 different sampling schemes: PMCABC, SABC, ABCsubsim and APMCABC)
# PMCABC is needed by the first inference run below and has to be imported before it is used.
from abcpy.inferences import PMCABC

## Generate a fake observation
true_parameter_values = [2, .1]
observation = lorenz.forward_simulate(true_parameter_values + [sigma_e, phi, T], 1,
                                      rng = np.random.RandomState(42))

# Reuse the stored journal if it exists, otherwise run the inference and store its result
try:
    journal = Journal.fromFile("Results/lorenz_hakkarainen_pmcabc.jrnl")
except FileNotFoundError:
    print("Run inference with PMCABC")
    sampler = PMCABC([lorenz], [distance_calculator], backend, kernel, seed = 1)
    # Define sampling parameters
    steps, n_samples, n_samples_per_param, full_output = 3, 10000, 1, 0
    eps_arr = np.array([500]); eps_percentile = 10

    # Sample
    journal = sampler.sample([observation], steps, eps_arr, n_samples,
                             n_samples_per_param, eps_percentile, full_output = full_output)
    # save the final journal file
    journal.save("Results/lorenz_hakkarainen_pmcabc.jrnl")

# print posterior mean and variance
print(journal.posterior_mean())
print(journal.posterior_cov())

# plot the posterior
journal.plot_posterior_distr(double_marginals_only = True, show_samples = False,
                             true_parameter_values = true_parameter_values,
                             path_to_save = "../Figures/lorenz_hakkarainen_pmcabc.pdf")


# ----------------------------
#
# Section 5
#
# ----------------------------


from abcpy.inferences import SABC, APMCABC, ABCsubsim

# SABC: reuse the stored journal if it exists, otherwise run the sampler and store its result
try:
    journal = Journal.fromFile("Results/lorenz_hakkarainen_sabc.jrnl")
except FileNotFoundError:
    print("Run with inference SABC")
    sampler = SABC([lorenz], [distance_calculator], backend, kernel, seed = 1)
    # Sampling configuration
    # (quicker alternative: steps = 2, n_samples = 100, n_samples_per_param = 1, full_output = 0)
    steps = 20
    n_samples = 10000
    n_samples_per_param = 1
    full_output = 0
    epsilon = 500
    # Draw the samples
    journal = sampler.sample(
        [observation], steps, epsilon, n_samples, n_samples_per_param, full_output = full_output
    )
    # Store the resulting journal for later runs
    journal.save("Results/lorenz_hakkarainen_sabc.jrnl")
# Plot the posterior obtained with SABC
journal.plot_posterior_distr(
    double_marginals_only = True,
    show_samples = False,
    true_parameter_values = true_parameter_values,
    path_to_save = "../Figures/lorenz_hakkarainen_sabc.pdf",
)

# ABCsubsim: reuse the stored journal if it exists, otherwise run the sampler and store its result
try:
    journal = Journal.fromFile("Results/lorenz_hakkarainen_abcsubsim.jrnl")
except FileNotFoundError:
    print("Run inference with ABCsubsim")
    sampler = ABCsubsim([lorenz], [distance_calculator], backend, kernel, seed = 1)
    # Sampling configuration
    # (quicker alternative: steps = 2, n_samples = 100, n_samples_per_param = 1, full_output = 0)
    steps = 20
    n_samples = 10000
    n_samples_per_param = 1
    full_output = 1
    # Draw the samples
    journal = sampler.sample(
        [observation], steps, n_samples, n_samples_per_param, full_output = full_output
    )
    # Store the resulting journal for later runs
    journal.save("Results/lorenz_hakkarainen_abcsubsim.jrnl")
# Plot the posterior obtained with ABCsubsim
journal.plot_posterior_distr(
    double_marginals_only = True,
    show_samples = False,
    true_parameter_values = true_parameter_values,
    path_to_save = "../Figures/lorenz_hakkarainen_abcsubsim.pdf",
)

# APMCABC: reuse the stored journal if it exists, otherwise run the sampler and store its result
try:
    journal = Journal.fromFile("Results/lorenz_hakkarainen_apmcabc.jrnl")
except FileNotFoundError:
    print("Run inference with APMCABC")
    sampler = APMCABC([lorenz], [distance_calculator], backend, kernel, seed = 1)
    # Define sampling parameters
    steps, n_samples, n_samples_per_param, full_output = 20, 10000, 1, 1
    # steps, n_samples, n_samples_per_param, full_output = 2, 100, 1, 1 # quicker
    acceptance_cutoff = 0.003
    # Sample; the cutoff is passed by keyword so that the value chosen above is actually used
    # instead of the sampler's default
    journal = sampler.sample([observation], steps, n_samples,
                             n_samples_per_param, acceptance_cutoff = acceptance_cutoff,
                             full_output = full_output)
    # save the final journal file
    journal.save("Results/lorenz_hakkarainen_apmcabc.jrnl")
# plot the posterior
journal.plot_posterior_distr(double_marginals_only = True, show_samples = False,
                             true_parameter_values = true_parameter_values,
                             path_to_save = "../Figures/lorenz_hakkarainen_apmcabc.pdf")



# ----------------------------
#
# Section 5 Figures for RUNTIMES
#
# ----------------------------
# ----------------------------------------------------------------------------------------------------
# Below we provide runtime data for Lorenz runs on MPI, AWS etc. with different types of parallelization
#------------------------------------------------------------------------------------------------------------

import numpy as np
import matplotlib.pyplot as plt

# # dyn-mpi
# Each raw table has one column per node count (2, 4, ..., 256); the rows are
# averaged column-wise to get one runtime per configuration.
# Speedup is taken relative to the first (smallest) configuration, efficiency
# is speedup per core rescaled so that its largest value equals 1.

# > PMCABC
pmcabc_dyn_mpi_raw = [
    [4419.175606, 2316.508076, 1281.6908, 613.1040618,
     424.8041892, 333.3324294, 265.8309147, 312.2090373],
    [5075.024077, 2488.037973, 1307.257664, 714.9290788,
     446.6826365, 372.9534166, 357.4033408, 305.1729059],
    [5048.320467, 2439.768328, 1346.55576, 711.4803486,
     467.7459741, 269.6679242, 266.5329554, 253.5670722],
    [3926.589782, 1967.703535, 1044.190016, 594.2274213,
     389.304249, 288.7576969, 265.8936141, 298.4229198],
    [4955.313391, 2587.397302, 1419.184933, 707.9941218,
     431.6817532, 409.5818915, 298.1922028, 330.3212276],
]

pmcabc_runtimes_dyn_mpi = np.asarray(pmcabc_dyn_mpi_raw).mean(axis = 0)
pmcabc_num_nodes_dyn_mpi = 2 ** np.arange(1, 9)  # 2, 4, ..., 256
pmcabc_num_cores_dyn_mpi = 36 * pmcabc_num_nodes_dyn_mpi - 1

pmcabc_speedup_dyn_mpi = pmcabc_runtimes_dyn_mpi[0] / pmcabc_runtimes_dyn_mpi
pmcabc_efficiency_dyn_mpi = pmcabc_speedup_dyn_mpi / pmcabc_num_cores_dyn_mpi
pmcabc_efficiency_dyn_mpi = pmcabc_efficiency_dyn_mpi / pmcabc_efficiency_dyn_mpi.max()

# > SABC
sabc_dyn_mpi_raw = [
    [378.864748, 191.5074463, 98.82132864, 59.75189757,
     36.76258564, 25.19009137, 26.71395397, 58.38384438],
    [375.1854441, 189.3980782, 96.36304498, 50.72541738,
     27.43537879, 16.87575555, 13.50409412, 30.28837824],
    [375.3394582, 189.3887148, 96.60298038, 50.72501993,
     27.4957695, 16.93699384, 13.61744761, 30.78951073],
    [375.550678, 189.4977794, 96.61252856, 50.77751827,
     27.4807148, 17.02190328, 13.57771993, 30.27911067],
    [375.5202515, 189.1094029, 96.46440983, 50.75433683,
     27.47793436, 16.9701767, 13.57473397, 30.18758988],
]

sabc_runtimes_dyn_mpi = np.asarray(sabc_dyn_mpi_raw).mean(axis = 0)
sabc_num_nodes_dyn_mpi = 2 ** np.arange(1, 9)  # 2, 4, ..., 256
sabc_num_cores_dyn_mpi = 36 * sabc_num_nodes_dyn_mpi - 1

sabc_speedup_dyn_mpi = sabc_runtimes_dyn_mpi[0] / sabc_runtimes_dyn_mpi
sabc_efficiency_dyn_mpi = sabc_speedup_dyn_mpi / sabc_num_cores_dyn_mpi
sabc_efficiency_dyn_mpi = sabc_efficiency_dyn_mpi / sabc_efficiency_dyn_mpi.max()

# # MPI
# One runtime per node count (2, 4, 8, 16, 32); these "_mc" series are the
# ones drawn with the label "MPI on Daint" in the figures further down.
# Speedup is relative to the first entry, efficiency is speedup per core
# rescaled to a maximum of 1.

# > PMCABC
pmcabc_runtimes_mpi_mc = np.array([
    5701.86419301033,
    3161.279069328308,
    1896.5004853010178,
    1068.7664306402207,
    684.1660243988038,
])
pmcabc_num_nodes_mpi_mc = 2 ** np.arange(1, 6)  # 2, 4, 8, 16, 32
pmcabc_num_cores_mpi_mc = 36 * pmcabc_num_nodes_mpi_mc - 1

pmcabc_speedup_mpi_mc = pmcabc_runtimes_mpi_mc[0] / pmcabc_runtimes_mpi_mc
pmcabc_efficiency_mpi_mc = pmcabc_speedup_mpi_mc / pmcabc_num_cores_mpi_mc
pmcabc_efficiency_mpi_mc = pmcabc_efficiency_mpi_mc / pmcabc_efficiency_mpi_mc.max()

# > SABC
sabc_runtimes_mpi_mc = np.array([
    410.6378993988037,
    209.3689744234085,
    106.65136272907257,
    58.12886316776276,
    33.83404488563538,
])
sabc_num_nodes_mpi_mc = 2 ** np.arange(1, 6)  # 2, 4, 8, 16, 32
sabc_num_cores_mpi_mc = 36 * sabc_num_nodes_mpi_mc - 1

sabc_speedup_mpi_mc = sabc_runtimes_mpi_mc[0] / sabc_runtimes_mpi_mc
sabc_efficiency_mpi_mc = sabc_speedup_mpi_mc / sabc_num_cores_mpi_mc
sabc_efficiency_mpi_mc = sabc_efficiency_mpi_mc / sabc_efficiency_mpi_mc.max()

# # MPI
# A second MPI benchmark: node counts 3, 9, 25, 65, 129, 417 with
# 12 * nodes - 1 cores each. Rows of the raw table are averaged column-wise.
# >PMCABC

pmcabc_runtimes_mpi_raw = np.array([
    [8754.386101, 3083.351378, 1202.855324, 725.9162221, 468.175307, 307.3471186],
    [8990.943574, 3453.375321, 1506.616017, 739.4082046, 583.8848307, 285.4516006],
    [9260.638548, 3384.452699, 1453.021729, 761.6351624, 483.9109256, 273.7452123],
    [7621.587835, 2767.080381, 1199.700564, 581.5635383, 399.0376515, 237.9761641],
])

pmcabc_runtimes_mpi = pmcabc_runtimes_mpi_raw.mean(axis = 0)
pmcabc_num_nodes_mpi = np.array([3, 9, 25, 65, 129, 417])
pmcabc_num_cores_mpi = 12 * pmcabc_num_nodes_mpi - 1

pmcabc_speedups_mpi = pmcabc_runtimes_mpi[0] / pmcabc_runtimes_mpi

# The last configuration is dropped from speedups, nodes and cores to make
# the graphs easier to look at (pmcabc_runtimes_mpi keeps all six entries).
pmcabc_speedups_mpi, pmcabc_num_nodes_mpi, pmcabc_num_cores_mpi = (
    series[:-1]
    for series in (pmcabc_speedups_mpi, pmcabc_num_nodes_mpi, pmcabc_num_cores_mpi))

# Efficiency is speedup per core, rescaled so that its maximum equals 1
pmcabc_efficiency_mpi = pmcabc_speedups_mpi / pmcabc_num_cores_mpi
pmcabc_efficiency_mpi = pmcabc_efficiency_mpi / pmcabc_efficiency_mpi.max()

# # Ideal Speedup
# Linear scaling in the (trimmed) node counts, relative to the first one

speedup_ideal = pmcabc_num_nodes_mpi / pmcabc_num_nodes_mpi[0]

# # Spark AWS
# One runtime per node count (2, 4, 8, 16, 32), with 36 * nodes - 1 cores.
# Speedup is relative to the first entry; efficiency is speedup per core
# rescaled so that its maximum equals 1.
# > PMCABC
pmcabc_runtimes_aws = np.array([
    9437.556269550323,
    4782.96914525032,
    2460.2412045955657,
    1324.7504865169526,
    762.9563060760498,
])
pmcabc_num_nodes_aws = 2 ** np.arange(1, 6)  # 2, 4, 8, 16, 32
pmcabc_num_cores_aws = 36 * pmcabc_num_nodes_aws - 1

pmcabc_speedups_aws = pmcabc_runtimes_aws[0] / pmcabc_runtimes_aws
pmcabc_efficiency_aws = pmcabc_speedups_aws / pmcabc_num_cores_aws
pmcabc_efficiency_aws = pmcabc_efficiency_aws / pmcabc_efficiency_aws.max()

# # Spark Original Benchmarking (NTT = 5760)
# Speedup is relative to the first configuration; efficiency is speedup per
# core rescaled so that its maximum equals 1. Cores are 36 per counted node.
# > PMCABC

pmcabc_runtimes_spark = np.array([12589.6, 6169.57143, 3284.66666667, 1847.83333333, 1089.28571429])
# One node is subtracted from each listed node count before counting cores
pmcabc_num_nodes_spark = np.subtract([3, 5, 9, 17, 33], 1)
pmcabc_num_cores_spark = 36 * pmcabc_num_nodes_spark

pmcabc_speedups_spark = pmcabc_runtimes_spark[0] / pmcabc_runtimes_spark
pmcabc_efficiency_spark = pmcabc_speedups_spark / pmcabc_num_cores_spark
pmcabc_efficiency_spark = pmcabc_efficiency_spark / pmcabc_efficiency_spark.max()

# > SABC
# One column per node count (2, 4, ..., 256); rows are averaged column-wise.
sabc_runtimes_spark1024_raw = [
    [834.1416621, 417.370055, 335.7132995, 116.2761066,
     72.28094292, 53.2144928, 57.60567093, 84.10586905],
    [329.9083505, 409.7035534, 330.4003179, 109.0879076,
     59.05881071, 35.67840767, 25.64293742, 23.33583498],
    [78.51307583, 409.5178506, 329.8868394, 108.8763447,
     59.07009077, 35.70212293, 25.4355576, 23.20267606],
    [78.38705969, 409.58481, 329.7591956, 108.885432,
     58.68075657, 36.20251465, 25.2525754, 23.35074496],
    [78.39588737, 409.5513053, 329.9188914, 108.8227527,
     58.72610879, 35.26226592, 25.37326837, 23.48383713],
]

sabc_runtimes_spark1024 = np.asarray(sabc_runtimes_spark1024_raw).mean(axis = 0)
sabc_num_nodes_spark1024 = 2 ** np.arange(1, 9)  # 2, 4, ..., 256
sabc_num_cores_spark1024 = 36 * sabc_num_nodes_spark1024

sabc_speedups_spark1024 = sabc_runtimes_spark1024[0] / sabc_runtimes_spark1024
sabc_efficiency_spark1024 = sabc_speedups_spark1024 / sabc_num_cores_spark1024
sabc_efficiency_spark1024 = sabc_efficiency_spark1024 / sabc_efficiency_spark1024.max()

# ## Speedup & Efficiency Diagrams for PMCABC (Without dyn-mpi)
# Two figures are produced: speedup and efficiency of PMCABC as a function of
# the number of cores, one curve each for Spark on AWS, Spark on Daint and
# MPI on Daint. Both are written as EPS files under ../Figures.

# --- speedup ---
plt.figure(figsize = (8, 6))
plt.plot(pmcabc_num_cores_aws, pmcabc_speedups_aws, c = "black", linestyle = "solid", label = "Spark on AWS", marker = ".",
         markersize = 10)
# The Spark-on-Daint series is drawn against its own core counts
# (pmcabc_num_cores_spark) rather than against the AWS core counts.
plt.plot(pmcabc_num_cores_spark, pmcabc_speedups_spark, c = "black", linestyle = ":", label = "Spark on Daint", marker = ".",
         markersize = 10)

plt.plot(pmcabc_num_cores_mpi_mc, pmcabc_speedup_mpi_mc, c = "black", linestyle = "dashed", label = "MPI on Daint",
         marker = "*", markersize = 10)

plt.xlabel("n", fontsize = 30)
plt.ylabel(r"$\mathcal{S}_{\mathcal{A}}$", fontsize = 30)
plt.legend(loc = "best", frameon = False, numpoints = 1, fontsize = 15)
plt.tight_layout()
plt.xticks(pmcabc_num_cores_aws, fontsize = 15)
plt.yticks(fontsize = 15)
plt.savefig("../Figures/perf_speedup_spark_mpi.eps")

# --- efficiency ---
plt.figure(figsize = (8, 6))
plt.plot(pmcabc_num_cores_aws, pmcabc_efficiency_aws, c = "black", linestyle = "solid", label = "Spark on AWS", marker = ".",
         markersize = 10)
plt.plot(pmcabc_num_cores_spark, pmcabc_efficiency_spark, c = "black", linestyle = ":", label = "Spark on Daint", marker = ".",
         markersize = 10)

plt.plot(pmcabc_num_cores_mpi_mc, pmcabc_efficiency_mpi_mc, c = "black", linestyle = "dashed", label = "MPI on Daint",
         marker = "*", markersize = 10)

# Gray reference line at efficiency 0.5; its length follows the x-axis array
# instead of a hard-coded count.
plt.plot(pmcabc_num_cores_aws, [0.5] * len(pmcabc_num_cores_aws), c = "gray", linestyle = "dashed")

plt.xlabel("n", fontsize = 30)
plt.ylabel(r"$\mathcal{E}_{\mathcal{A}}$", fontsize = 30)
plt.legend(loc = "best", frameon = False, numpoints = 1, fontsize = 15)
plt.tight_layout()
plt.xticks(pmcabc_num_cores_aws, fontsize = 15)
plt.yticks(fontsize = 15)

plt.savefig("../Figures/perf_efficiency_spark_mpi.eps")

# Different methods performance
# Speedup of PMCABC, APMCABC, ABCsubsim and SABC versus the number of cores.
# Speedups are derived from the tabulated efficiencies as
# efficiency * cores / 576, i.e. relative to the first core count.
plt.figure(figsize = (8, 6))

# NOTE(review): 1151 breaks the doubling pattern of the other entries
# (2 * 576 = 1152) -- confirm this value is intended.
cores = np.array([576, 1151, 2304, 4608, 9216])
cores_ticks = np.array([576, 2304, 4608, 9216])

efficiency_pmcabc = np.array([1, 0.790165, 0.537935, 0.31004, 0.1532775])
efficiency_apmcabc = np.array([1, 0.894945, 0.5597275, 0.358885, 0.2513225])
efficiency_abcsubsim = np.array([1, 0.7084275, 0.3608425, 0.1792225, 0.0882475])
efficiency_sabc = np.array([1, 0.777665, 0.5277575, 0.346985, 0.227905])

speedup_pmcabc, speedup_apmcabc, speedup_abcsubsim, speedup_sabc = (
    eff * cores / 576
    for eff in (efficiency_pmcabc, efficiency_apmcabc, efficiency_abcsubsim, efficiency_sabc))

# One black curve per algorithm, told apart by line style only
for _curve, _dash, _name in ((speedup_pmcabc, "solid", "PMCABC"),
                             (speedup_apmcabc, "dotted", "APMCABC"),
                             (speedup_abcsubsim, "dashdot", "ABCsubsim"),
                             (speedup_sabc, "dashed", "SABC")):
    plt.plot(cores, _curve, c = "black", linestyle = _dash, label = _name, marker = "*", markersize = 10, lw = 1)

plt.xlabel("n", fontsize = 30)
plt.ylabel(r"$\mathcal{S}_{\mathcal{A}}$", fontsize = 30)
plt.legend(loc = "best", frameon = False, numpoints = 1, fontsize = 15)
plt.tight_layout()
plt.xticks(cores_ticks, fontsize = 15)
plt.yticks([2, 4], fontsize = 15)
plt.ylim(1)
plt.savefig("../Figures/comparison_speedup.eps")

# Efficiency of the four algorithms versus the number of cores, with a gray
# dashed reference line at 0.5.
plt.figure(figsize = (8, 6))
for _curve, _dash, _name in ((efficiency_pmcabc, "solid", "PMCABC"),
                             (efficiency_apmcabc, "dotted", "APMCABC"),
                             (efficiency_abcsubsim, "dashdot", "ABCsubsim"),
                             (efficiency_sabc, "dashed", "SABC")):
    plt.plot(cores, _curve, c = "black", linestyle = _dash, label = _name, marker = "*", markersize = 10, lw = 1)

plt.plot(cores, [0.5] * 5, c = "gray", linestyle = "dashed")

plt.xlabel("n", fontsize = 30)
plt.ylabel(r"$\mathcal{E}_{\mathcal{A}}$", fontsize = 30)
plt.legend(loc = "best", frameon = False, numpoints = 1, fontsize = 15)
plt.tight_layout()
plt.xticks(cores_ticks, fontsize = 15)
plt.yticks(fontsize = 15)
plt.savefig("../Figures/comparison_efficiency.eps")

# # dyn-mpi performance
# > PMCABC
# Speedup and efficiency of PMCABC versus the number of cores for Spark on
# Daint, MPI on Daint and dynamic-MPI on Daint. Both figures are written as
# EPS files under ../Figures.

# --- speedup ---
plt.figure(figsize = (8, 6))
# The Spark-on-Daint series is drawn against its own core counts
# (pmcabc_num_cores_spark) rather than against the AWS core counts.
plt.plot(pmcabc_num_cores_spark, pmcabc_speedups_spark, c = "black", linestyle = ":", label = "Spark on Daint", marker = ".",
         markersize = 10)

plt.plot(pmcabc_num_cores_mpi_mc, pmcabc_speedup_mpi_mc, c = "black", linestyle = "dashed", label = "MPI on Daint",
         marker = "*", markersize = 10)
# The last three dynamic-MPI configurations are left out, so this curve has
# the same number of points as the other two.
plt.plot(pmcabc_num_cores_dyn_mpi[:-3], pmcabc_speedup_dyn_mpi[:-3], c = "black", linestyle = "-.",
         label = "dynamic-MPI on Daint", marker = "*", markersize = 10)

plt.xlabel("n", fontsize = 30)
plt.ylabel(r"$\mathcal{S}_{\mathcal{A}}$", fontsize = 30)
plt.legend(loc = "best", frameon = False, numpoints = 1, fontsize = 15)
plt.tight_layout()
plt.xticks(pmcabc_num_cores_aws, fontsize = 15)
plt.yticks(fontsize = 15)
plt.savefig("../Figures/perf_speedup_spark_mpi_dyn.eps")

# --- efficiency ---
plt.figure(figsize = (8, 6))
plt.plot(pmcabc_num_cores_spark, pmcabc_efficiency_spark, c = "black", linestyle = ":", label = "Spark on Daint", marker = ".",
         markersize = 10)

plt.plot(pmcabc_num_cores_mpi_mc, pmcabc_efficiency_mpi_mc, c = "black", linestyle = "dashed", label = "MPI on Daint",
         marker = "*", markersize = 10)
plt.plot(pmcabc_num_cores_dyn_mpi[:-3], pmcabc_efficiency_dyn_mpi[:-3], c = "black", linestyle = "-.",
         label = "dynamic-MPI on Daint", marker = "*", markersize = 10)

# Gray reference line at efficiency 0.5; its length follows the x-axis array
# instead of a hard-coded count.
plt.plot(pmcabc_num_cores_aws, [0.5] * len(pmcabc_num_cores_aws), c = "gray", linestyle = "dashed")

plt.xlabel("n", fontsize = 30)
plt.ylabel(r"$\mathcal{E}_{\mathcal{A}}$", fontsize = 30)
plt.legend(loc = "best", frameon = False, numpoints = 1, fontsize = 15)
plt.tight_layout()
plt.xticks(pmcabc_num_cores_aws, fontsize = 15)
plt.yticks(fontsize = 15)

plt.savefig("../Figures/perf_efficiency_spark_mpi_dyn.eps")


# ----------------------------
#
# Section 6.1
#
# ----------------------------


### Now we illustrate how to use newly learned summary statistics


## Statistic applied to the simulations before the learned transformation
preprocessing_statistics = Identity(cross = False, degree = 1)

## Reuse a previously saved journal when there is one. Otherwise learn the
## summary statistics with a neural network (the Partially Exchangeable
## Network), rebuild the distance on top of them and run SABC.
try:
    journal = Journal.fromFile("Results/lorenz_learned_stats_sabc.jrnl")
except FileNotFoundError:
    # The embedding network is assembled from its two sub-networks
    phi_net = PhiNetwork()
    rho_net = RhoNetwork(n_parameters = 2)
    embedding_net = PEN1(phi_net, rho_net, n_timestep = T)

    print("Learn summary stats...")
    # SemiautomaticNN learns the statistics from the model and the network
    summary_selection = SemiautomaticNN(
        [lorenz], preprocessing_statistics, backend, embedding_net,
        seed = 12, n_samples = 500)

    # The learned statistic replaces the one used by the distance
    statistics_calculator = summary_selection.get_statistics()
    distance_calculator = Euclidean(statistics_calculator)

    print("Run inference with SABC")
    sampler = SABC([lorenz], [distance_calculator], backend, kernel, seed = 1)

    # Sampling parameters
    # (quicker alternative: steps, n_samples, n_samples_per_param, full_output = 2, 10, 1, 0)
    steps = 20
    n_samples = 10000
    n_samples_per_param = 1
    full_output = 0
    epsilon = 500

    # Run the sampler and keep the journal for later runs
    journal = sampler.sample(
        [observation], steps, epsilon, n_samples, n_samples_per_param,
        full_output = full_output)
    journal.save("Results/lorenz_learned_stats_sabc.jrnl")

# ----------------------------
# Figure 8
# ----------------------------

# Posterior plot for the inference with learned statistics
journal.plot_posterior_distr(
    path_to_save = "../Figures/lorenz_learned_stats_sabc.pdf",
    true_parameter_values = true_parameter_values,
    show_samples = False,
    double_marginals_only = True)

# ----------------------------
#
# Section 6.2
#
# ----------------------------
# The school example of Sections 6.2-6.4 only runs when this flag is True.
run_school_example = False

if run_school_example:
    # Observed grades
    grades_obs = [
        3.872486707973337, 4.6735380808674405, 3.9703538990858376, 4.11021272048805, 4.211048655421368,
        4.154817956586653, 4.0046893064392695, 4.01891381384729, 4.123804757702919, 4.014941267301294,
        3.888174595940634, 4.185275142948246, 4.55148774469135, 3.8954427675259016, 4.229264035335705,
        3.839949451328312, 4.039402553532825, 4.128077814241238, 4.361488645531874, 4.086279074446419,
        4.370801602256129, 3.7431697332475466, 4.459454162392378, 3.8873973643008255, 4.302566721487124,
        4.05556051626865, 4.128817316703757, 3.8673704442215984, 4.2174459453805015, 4.202280254493361,
        4.072851400451234, 3.795173229398952, 4.310702877332585, 4.376886328810306, 4.183704734748868,
        4.332192463368128, 3.9071312388426587, 4.311681374107893, 3.55187913252144, 3.318878360783221,
        4.187850500877817, 4.207923106081567, 4.190462065625179, 4.2341474252986036, 4.110228694304768,
        4.1589891480847765, 4.0345604687633045, 4.090635481715123, 3.1384654393449294, 4.20375641386518,
        4.150452690356067, 4.015304457401275, 3.9635442007388195, 4.075915739179875, 3.5702080541929284,
        4.722333310410388, 3.9087618197155227, 4.3990088006390735, 3.968501165774181, 4.047603645360087,
        4.109184340976979, 4.132424805281853, 4.444358334346812, 4.097211737683927, 4.288553086265748,
        3.8668863066511303, 3.8837108501541007]

    # Random variables of the model; class_size and no_teacher both have a
    # mean that depends on school_budget.
    from abcpy.continuousmodels import Uniform, Normal
    school_budget = Uniform([[1], [10]], name = "school_budget")
    class_size = Normal([[800 * school_budget], [1]], name = "class_size")
    no_teacher = Normal([[20 * school_budget], [1]], name = "no_teacher")
    historical_mean_grade = Normal([[4.5], [0.25]], name = "historical_mean_grade")

    # The final grade combines the three variables above
    final_grade = historical_mean_grade - 0.001 * class_size + 0.02 * no_teacher

    # ----------------------------
    #
    # Section 6.3
    #
    # ----------------------------
    # Observed scholarships
    scholarship_obs = [
        2.7179657436207805, 2.124647285937229, 3.07193407853297, 2.335024761813643, 2.871893855192,
        3.4332002458233837, 3.649996835818173, 3.50292335102711, 2.815638168018455, 2.3581613289315992,
        2.2794821846395568, 2.8725835459926503, 3.5588573782815685, 2.26053126526137, 1.8998143530749971,
        2.101110815311782, 2.3482974964831573, 2.2707679029919206, 2.4624550491079225, 2.867017757972507,
        3.204249152084959, 2.4489542437714213, 1.875415915801106, 2.5604889644872433, 3.891985093269989,
        2.7233633223405205, 2.2861070389383533, 2.9758813233490082, 3.1183403287267755,
        2.911814060853062, 2.60896794303205, 3.5717098647480316, 3.3355752461779824, 1.99172284546858,
        2.339937680892163, 2.9835630207301636, 2.1684912355975774, 3.014847335983034, 2.7844122961916202,
        2.752119871525148, 2.1567428931391635, 2.5803629307680644, 2.7326646074552103, 2.559237193255186,
        3.13478196958166, 2.388760269933492, 3.2822443541491815, 2.0114405441787437, 3.0380056368041073,
        2.4889680313769724, 2.821660164621084, 3.343985964873723, 3.1866861970287808, 4.4535037154856045,
        3.0026333138006027, 2.0675706089352612, 2.3835301730913185, 2.584208398359566, 3.288077633446465,
        2.6955853384148183, 2.918315169739928, 3.2464814419322985, 2.1601516779909433, 3.231003347780546,
        1.0893224045062178, 0.8032302688764734, 2.868438615047827]

    # The scholarship shares no_teacher with the grade model
    historical_mean_scholarship = Normal([[2], [0.5]], name = "historical_mean_scholarship")
    final_scholarship = historical_mean_scholarship + 0.03 * no_teacher

    # One summary statistic per observed data set
    from abcpy.statistics import Identity
    statistics_final_grade = Identity(degree = 2, cross = False)
    statistics_final_scholarship = Identity(degree = 3, cross = False)

    # One distance per data set, each built on its own statistic
    from abcpy.distances import Euclidean
    distance_final_grade = Euclidean(statistics_final_grade)
    distance_final_scholarship = Euclidean(statistics_final_scholarship)

    # Backend
    from abcpy.backends import BackendDummy as Backend
    backend = Backend()

    # A single perturbation kernel over all five random variables
    from abcpy.perturbationkernel import DefaultKernel
    kernel = DefaultKernel(
        [school_budget, class_size, historical_mean_grade, no_teacher, historical_mean_scholarship])

    # Sampling parameters
    T = 3
    n_sample = 250
    n_samples_per_param = 10
    eps_arr = np.array([0.75])
    eps_percentile = 10

    # PMCABC sampler over both root models, with their respective distances
    from abcpy.inferences import PMCABC
    sampler = PMCABC(
        [final_grade, final_scholarship],
        [distance_final_grade, distance_final_scholarship],
        backend, kernel)

    # Sample
    journal = sampler.sample(
        [grades_obs, scholarship_obs], T, eps_arr,
        n_sample, n_samples_per_param, eps_percentile)

    # ----------------------------
    #
    # Section 6.4
    #
    # ----------------------------

    # Two kernels over disjoint groups of variables ...
    from abcpy.perturbationkernel import MultivariateNormalKernel, MultivariateStudentTKernel
    kernel_1 = MultivariateNormalKernel(
        [school_budget, historical_mean_grade, historical_mean_scholarship])
    kernel_2 = MultivariateStudentTKernel([class_size, no_teacher], df = 3)

    # ... joined into one perturbation kernel
    from abcpy.perturbationkernel import JointPerturbationKernel
    kernel = JointPerturbationKernel([kernel_1, kernel_2])




