# -*- coding: utf-8 -*-
"""extending_romc_NN.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1_jHVxPSH3XcNOORZJpLU0SPzs0PF8CQ5

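Extending ROMC with a Neural Network: each optimisation problem's local surrogate is fitted as an MLP regressor (instead of the base class's default), demonstrated on ELFI's MA(2) example.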
"""

# Dependency: ELFI. Install it with `python -m pip install elfi`
# (in a notebook cell: `!pip install elfi`). The IPython `!` shell escape is
# not valid Python syntax, so it is kept as a comment to let this file run as
# a plain script.

import timeit

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits import mplot3d
from sklearn.pipeline import Pipeline
from sklearn.neural_network import MLPRegressor
from functools import partial
from elfi.examples import ma2
import elfi
# from elfi.examples import ma2

# Seed NumPy's global RNG for reproducibility.
seed = 1
np.random.seed(seed)

# ELFI's MA(2) example model; the same seed is passed as `seed_obs`
# (presumably seeding the synthetic observed data — see elfi.examples.ma2).
model = ma2.get_model(seed_obs=seed)

from elfi import examples

# ROMC experiment settings.
# Search bounds for (t1, t2), one (low, high) pair per parameter; passed to elfi.ROMC.
bounds = [(-2, 2), (-1.25, 1.25)]
# Dimensionality of the parameter space, derived from `bounds` so the two agree.
dim = len(bounds)
# Passed as `n1` to romc.solve_problems (number of optimisation problems).
n1 = 100
# Passed as `n2` to romc.sample.
n2 = 200
# Passed as `eps_filter` to romc.estimate_regions.
eps = .01

# ROMC optimisation problem whose local surrogates are small neural networks.
class CustomOptim(elfi.methods.inference.romc.OptimisationProblem):
    """ROMC optimisation problem with an MLP-regressor local surrogate.

    Overrides the base class's ``fit_local_surrogate``: for every entry of
    ``self.regions``, points are sampled inside the region, the problem's
    objective is evaluated on them, and a fresh ``MLPRegressor`` is fitted to
    the (theta, objective) pairs.
    """

    # Number of points sampled from each region to train its surrogate.
    # Subclasses may override it.
    nof_surrogate_samples = 500

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def fit_local_surrogate(self, **kwargs):
        """Fit one neural-network surrogate per region.

        Parameters
        ----------
        kwargs: Dict
            Accepted for compatibility with the base-class signature; unused.

        Sets ``self.local_surrogates`` to a list of callables (one per entry of
        ``self.regions``, in the same order), each mapping a 1-D parameter
        vector to a float, and sets ``self.state["local_surrogates"] = True``.
        """
        objective = self.objective

        local_surrogates = []
        for region in self.regions:
            # Training set: points inside the region and their objective values.
            x = region.sample(self.nof_surrogate_samples)
            y = np.array([objective(theta) for theta in x])

            # A fresh network per region, so each surrogate keeps its own fit
            # (reusing one estimator object would make all closures share the
            # last fit).
            mlp = MLPRegressor(hidden_layer_sizes=(10, 10), solver='adam')
            model = Pipeline([('mlp', mlp)]).fit(x, y)
            local_surrogates.append(self.create_local_surrogate(model))

        self.local_surrogates = local_surrogates
        self.state["local_surrogates"] = True

    @staticmethod
    def create_local_surrogate(model):
        """Wrap a fitted regressor as a function of one parameter vector.

        Parameters
        ----------
        model:
            Fitted estimator exposing ``predict`` (e.g. a scikit-learn Pipeline).

        Returns
        -------
        Callable
            ``f(th) -> float`` for a 1-D parameter vector ``th``.

        Raises
        ------
        ValueError
            (from the returned callable) if ``th`` is not 1-D.
        """
        def _local_surrogate(th):
            th = np.asarray(th)
            if th.ndim != 1:
                raise ValueError(
                    f"surrogate expects a 1-D parameter vector, got shape {th.shape}"
                )
            # predict needs a 2-D (n_samples, n_features) input and returns a
            # one-element array here; .item() extracts the scalar without the
            # deprecated ndim>0 array -> float conversion.
            return float(model.predict(np.expand_dims(th, 0)).item())
        return _local_surrogate

# Build ROMC on the model, plugging in the neural-network surrogate through
# `custom_optim_class`.
romc = elfi.ROMC(model, bounds=bounds, discrepancy_name="d",
                 custom_optim_class=CustomOptim)

# Fitting: solve n1 optimisation problems, then filter them with `eps` and
# estimate regions; fit_models=True requests local surrogate fitting.
romc.solve_problems(n1=n1, seed=seed)
romc.estimate_regions(eps_filter=eps, fit_models=True)

# Sampling.
romc.sample(n2=n2, seed=seed)

# ELFI's Sample.summary() prints its report itself and returns None, so it is
# not wrapped in print() (that would also print a stray "None").
romc.result.summary()
print(romc.result.samples_cov())

def plot_marginal(samples, weights, mean, std, title, xlabel, ylabel, bins, range, ylim):
    """Plot a weighted, density-normalised histogram of one parameter.

    A red dashed vertical line marks ``mean``; a black dashed horizontal
    segment at height 1 spans ``mean - std`` to ``mean + std``.

    Parameters
    ----------
    samples : array-like
        Sample values of a single parameter.
    weights : array-like
        Per-sample weights, passed to ``plt.hist``.
    mean, std : float
        Values marked on the plot and shown in the legend (3 decimals).
    title, xlabel, ylabel : str
        Figure title and axis labels.
    bins : int or sequence
        Passed to ``plt.hist``.
    range : (float, float)
        Histogram range, passed to ``plt.hist``. (The name shadows the builtin
        but is kept so keyword callers keep working.)
    ylim : (float, float)
        y-axis limits.
    """
    plt.figure()
    plt.title(title)
    plt.hist(x=samples, weights=weights, bins=bins, density=True, range=range)
    plt.axvline(mean, 0, 1,
                color="r", linestyle="--", label=r"$\mu = %.3f$" % mean)
    # Draw the std bar in data coordinates. axhline's xmin/xmax are fractions
    # of the axes' x-limits, which include autoscale margins beyond `range`,
    # so computing them from `range` misplaced the bar's endpoints.
    plt.hlines(1, mean - std, mean + std,
               colors="k", linestyles="--", label=r"$\sigma = %.3f$" % std)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.ylim(ylim)
    plt.legend()
    plt.show(block=False)

# Weighted marginal histograms of the ROMC samples, one figure per parameter.
plot_marginal(
    samples=romc.result.samples["t1"],
    weights=romc.result.weights,
    mean=romc.result.sample_means_array[0],
    std=np.sqrt(romc.result.samples_cov()[0, 0]),
    title=r"ROMC (Neural Network) - $\theta_1$",
    xlabel=r"$\theta_1$",
    ylabel=r"density",
    bins=60,
    range=(0.3, 1.2),
    ylim=(0, 3.5),
)

plot_marginal(
    samples=romc.result.samples["t2"],
    weights=romc.result.weights,
    mean=romc.result.sample_means_array[1],
    std=np.sqrt(romc.result.samples_cov()[1, 1]),
    title=r"ROMC (Neural Network) - $\theta_2$",
    xlabel=r"$\theta_2$",
    ylabel=r"density",
    bins=60,
    range=(-0.5, 1),
    ylim=(0, 3),
)

def plot_romc_posterior(title, posterior, nof_points, *, limits=None):
    """Draw a filled contour plot of a 2-D function over a regular grid.

    Parameters
    ----------
    title : str
        Figure title.
    posterior : callable
        Called once with an array of shape (nof_points**2, 2) whose rows are
        (theta_1, theta_2) grid points; its output is reshaped to
        (nof_points, nof_points).
    nof_points : int
        Number of grid points along each axis.
    limits : sequence of two (low, high) pairs, optional
        Grid extent for theta_1 and theta_2. Defaults to the module-level
        ``bounds``.
    """
    plt.figure()
    (lo1, hi1), (lo2, hi2) = bounds if limits is None else limits
    th1 = np.linspace(lo1, hi1, nof_points)
    th2 = np.linspace(lo2, hi2, nof_points)
    X, Y = np.meshgrid(th1, th2)

    # Evaluate the whole grid in one call, row-major to match X/Y's layout.
    grid = np.column_stack((X.ravel(), Y.ravel()))
    Z = np.asarray(posterior(grid)).reshape(nof_points, nof_points)

    plt.contourf(X, Y, Z, 50, cmap='viridis')
    plt.title(title)
    plt.xlabel(r"$\theta_1$")
    plt.ylabel(r"$\theta_2$")
    plt.colorbar()
    plt.show(block=False)

# Unnormalised ROMC posterior evaluated on a 50 x 50 grid.
plot_romc_posterior(
    'ROMC (Neural Network)',
    posterior=romc.eval_unnorm_posterior,
    nof_points=50,
)