# -*- coding: utf-8 -*-
"""example_MA2.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1nkdACQ370SSc0KB1bHv4sBRaxMlMqoNH
"""

# ELFI must be installed before running this script, e.g. `pip install elfi`.
# The notebook ran the command below as an IPython shell escape; that syntax
# is not valid Python, so it is kept commented out like the other notebook
# magics in this file.
# NOTE(review): the original note said "dev branch", but this command installs
# the published PyPI release -- confirm which version is intended.
# !pip install elfi

import numpy as np
import scipy.stats
import matplotlib
import matplotlib.pyplot as plt

# %matplotlib inline
# %precision 2

import logging
# Configure the root logger so that INFO-level (and higher) messages are shown.
logging.basicConfig(level=logging.INFO)

# Set seed for reproducibility: this seeds NumPy's global random state.
seed = 1
np.random.seed(seed)

# NOTE(review): ELFI is imported only after logging and the seed are set up;
# keep this order unless the import is confirmed to be side-effect free.
import elfi
from elfi.examples import ma2
# (Unused direct import; ROMC is reached through `elfi.ROMC` in this file.)
# from elfi.methods.parameter_inference import ROMC
# Build the MA2 example model, passing the same seed as `seed_obs`
# (presumably the seed for the observed data -- confirm in
# elfi.examples.ma2.get_model).
model = ma2.get_model(seed_obs=seed)

# Generate 1000 realisations of the "t1" and "t2" nodes from the model.
x = model.generate(
    1000,
    outputs=["t1", "t2"],
)

# Plot the draws with the "bo" format (blue circle markers):
# "t1" on the horizontal axis, "t2" on the vertical axis.
plt.figure()
plt.plot(x["t1"], x["t2"], "bo")
plt.title("Samples from the prior")
plt.show()

"""# Training part, using gradient-based optimisation"""

# Settings for this run, gathered in one place.
bounds = [(-2, 2), (-2, 2)]
n1 = 300
seed = 21
eps_filter = 0.02
n2 = 50

# Build the ROMC inference object on discrepancy node "d" and solve the
# optimisation problems. No `use_bo` flag is passed in this run.
romc = elfi.ROMC(model, bounds=bounds, discrepancy_name="d")
romc.solve_problems(n1=n1, seed=seed)

# Histogram (40 bins) produced by the solved problems.
romc.distance_hist(bins=40)

# Estimate the regions, fitting local models as well.
romc.estimate_regions(
    eps_filter=eps_filter,
    fit_models=True,
    eps_cutoff=0.1,
)

# Show the region with index 5 before any samples have been drawn.
romc.visualize_region(5)

# Draw the samples; the return value is kept in `tmp` as in the notebook.
tmp = romc.sample(n2=n2)

# Same region again -- this time the plot includes the samples as well.
romc.visualize_region(5)

"""# Rejection ABC - used as ground-truth information"""

# The notebook timed the sampling call with the `%time` IPython magic. The
# script export commented out that whole line, which left `result` undefined
# for every statement that uses it afterwards. The call is restored here
# without the magic so the script runs as plain Python.
N = 10000
rej = elfi.Rejection(model, discrepancy_name="d", batch_size=10000, seed=seed)
vis = dict(xlim=[-2, 2], ylim=[-1, 1])
result = rej.sample(N, threshold=0.1, vis=vis)

# Marginal histograms of the accepted samples, restricted to [-0.4, 1].
result.plot_marginals(range=[-0.4, 1], bins=70)

result.summary()

# Weighted 2-D histogram of the rejection-ABC samples, using columns 0 and 1
# of `samples_array` as the two axes.
plt.figure()
plt.hist2d(
    result.samples_array[:, 0],
    result.samples_array[:, 1],
    bins=60,
    range=[(-1, 1.5), (-1, 1)],
    weights=result.weights,
)
# `plt.plot` called without data draws nothing and has no `block` option; the
# non-blocking display call is `plt.show(block=False)`.
plt.show(block=False)

"""# ROMC Evaluation"""

# Weighted marginal histograms of the ROMC samples, 70 bins over [-0.4, 1].
romc.result.plot_marginals(
    weights=romc.result.weights,
    bins=70,
    range=(-0.4, 1),
)
plt.show()

# Summary of the ROMC result object.
romc.result.summary()

# Weighted 2-D histogram of the ROMC samples, using columns 0 and 1 of
# `samples_array` as the two axes.
plt.figure()
plt.hist2d(
    romc.result.samples_array[:, 0],
    romc.result.samples_array[:, 1],
    bins=60,
    range=[(-1, 1.5), (-1, 1)],
    weights=romc.result.weights,
)
# `plt.plot` called without data draws nothing and has no `block` option; the
# non-blocking display call is `plt.show(block=False)`.
plt.show(block=False)

def plot_romc_posterior(posterior, nof_points,
                        th1_lim=(-1, 1.5), th2_lim=(-1, 1)):
    """Draw a contour plot of a two-parameter posterior on a regular grid.

    Parameters
    ----------
    posterior : callable
        Called once with an array of shape ``(nof_points**2, 2)`` whose two
        columns are the th_1 and th_2 grid coordinates. It must return one
        value per row.
    nof_points : int
        Number of grid points along each axis.
    th1_lim, th2_lim : tuple of float, optional
        ``(lower, upper)`` limits of the grid along th_1 and th_2. The
        defaults reproduce the limits that used to be hard-coded.

    Returns
    -------
    None
        The contours are drawn on a new figure, which is shown without
        blocking.
    """
    th1 = np.linspace(th1_lim[0], th1_lim[1], nof_points)
    th2 = np.linspace(th2_lim[0], th2_lim[1], nof_points)
    X, Y = np.meshgrid(th1, th2)

    # One row per grid node, (th_1, th_2), so the posterior is evaluated in a
    # single call; the result is folded back onto the grid shape.
    th = np.column_stack((X.ravel(), Y.ravel()))
    Z = np.asarray(posterior(th)).reshape(X.shape)

    # Create the figure only after the posterior has been evaluated, so a
    # failing evaluation does not leave an empty figure behind.
    plt.figure()
    contours = plt.contour(X, Y, Z, 50, cmap='viridis')
    plt.title('ROMC Posterior PDF')
    plt.xlabel("th_1")
    plt.ylabel("th_2")
    # Hand the contour set to the colorbar explicitly instead of relying on
    # pyplot's implicit "current mappable".
    plt.colorbar(contours)
    plt.show(block=False)

# Contour plot of the ROMC posterior, evaluated on an 80 x 80 grid.
plot_romc_posterior(romc.eval_posterior, nof_points=80)

"""# Training Part, using Bayesian Optimisation"""

# Fresh ROMC object for this run, built with the same bounds and discrepancy
# node as the earlier gradient-based run.
bounds = [(-2, 2), (-2, 2)]
romc = elfi.ROMC(
    model,
    bounds=bounds,
    discrepancy_name="d",
)

# Solve the optimisation problems, this time with use_bo=True and n1=100.
n1, seed = 100, 21
romc.solve_problems(n1=n1, seed=seed, use_bo=True)
romc.distance_hist(bins=40)

# Estimate the regions, fitting local models as well, then show region 5.
eps_filter = 0.02
romc.estimate_regions(eps_filter=eps_filter, fit_models=True, eps_cutoff=0.1)
romc.visualize_region(5)

# Draw the samples; the return value is kept in `tmp` as in the notebook.
n2 = 50
tmp = romc.sample(n2=n2)

# Region 5 once more -- the plot now includes the samples as well.
romc.visualize_region(5)

"""# Evaluation"""

# Weighted marginal histograms of the current ROMC result (70 bins over
# [-0.4, 1]), followed by its summary.
romc.result.plot_marginals(
    weights=romc.result.weights, bins=70, range=(-0.4, 1)
)
plt.show()

romc.result.summary()

# Weighted 2-D histogram of the current ROMC samples, using columns 0 and 1
# of `samples_array` as the two axes.
plt.figure()
plt.hist2d(
    romc.result.samples_array[:, 0],
    romc.result.samples_array[:, 1],
    bins=60,
    range=[(-1, 1.5), (-1, 1)],
    weights=romc.result.weights,
)
# `plt.plot` called without data draws nothing and has no `block` option; the
# non-blocking display call is `plt.show(block=False)`.
plt.show(block=False)

# `plot_romc_posterior` is already defined earlier in this file (after the
# first ROMC evaluation). It is reused here rather than redefined, so there
# is a single definition to maintain.
plot_romc_posterior(romc.eval_posterior, nof_points=80)