import numpy as np
import pandas as pd
import time
import matplotlib.pyplot as plt
import seaborn as sns
from tabulate import tabulate

from gcimpute.gaussian_copula import GaussianCopula
from gcimpute.low_rank_gaussian_copula import LowRankGaussianCopula
from gcimpute.helper_data import load_GSS, load_movielens1m, load_FRED, load_whitewine
from gcimpute.helper_mask import mask_MCAR
from gcimpute.helper_evaluation import get_smae, get_mae

from matplotlib.backends.backend_pdf import PdfPages
import warnings


def run_GSS(loc_plots):
    """Run the imputation experiments on the dataset returned by ``load_GSS``.

    Steps, in order: plot the per-variable distributions, mask 10% of the
    entries with ``mask_MCAR``, impute with a default ``GaussianCopula`` and
    report the mean SMAE, plot the estimated copula correlation, print the
    variable types and a frequency summary table, refit with verbose output,
    trace the imputation error over training iterations, and finally compare
    the runtime of minibatch-offline training against the default training.

    Args:
        loc_plots: Target handed to ``Figure.savefig(..., format='pdf')`` for
            each of the three figures produced here. ``main`` passes an open
            ``PdfPages`` object, so each figure becomes one page.
    """
    data_gss = load_GSS()

    # --- Per-variable distribution plots -----------------------------------
    # The 2x4 grid holds at most 8 variables; indexing ``axes.flat`` raises if
    # the dataset has more columns rather than silently dropping them.
    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    for i, col in enumerate(data_gss.columns):
        ax = axes.flat[i]
        observed = data_gss[col].dropna()
        if col in ('AGE', 'WEEKSWRK'):
            # These two columns are drawn as fine-grained histograms.
            observed.hist(ax=ax, bins=60)
        else:
            # Every other column gets one bar per distinct value.
            observed.value_counts().sort_index().plot(kind='bar', ax=ax)
        ax.set_title(f'{col}, {data_gss[col].isna().mean():.2f} missing')
    fig.tight_layout()
    fig.savefig(loc_plots, bbox_inches='tight', format='pdf')
    # pyplot keeps every figure alive until it is closed explicitly.
    plt.close(fig)

    # --- Mask data, impute and evaluate -------------------------------------
    gss_masked = mask_MCAR(X=data_gss, mask_fraction=.1)
    model = GaussianCopula()
    Ximp = model.fit_transform(X=gss_masked)
    smae = get_smae(Ximp, x_true=data_gss, x_obs=gss_masked)
    print(f'SMAE average over all variables: {smae.mean():.3f}')

    # --- Estimated copula correlation ---------------------------------------
    corr_est = model.get_params()['copula_corr'].round(2)
    # Hide the upper triangle (diagonal included); only the strictly lower
    # triangle of the correlation estimate is drawn.
    mask = np.zeros_like(corr_est, dtype=bool)
    mask[np.triu_indices_from(mask)] = True
    names = data_gss.columns
    fig, ax = plt.subplots()
    sns.heatmap(
        corr_est,
        xticklabels=names, yticklabels=names,
        annot=True, mask=mask, square=True, cmap='vlag',
        ax=ax,
    )
    fig.savefig(loc_plots, bbox_inches='tight', format='pdf')
    plt.close(fig)

    # --- Variable type information ------------------------------------------
    for k, v in model.get_vartypes(feature_names=names).items():
        print(f'{k}: {", ".join(v)}')

    def key_freq(col):
        """Summarise the relative frequencies of one NaN-free column.

        Returns a Series (rounded to 2 decimals) with the frequency of the
        most common value, of the minimum value, of the maximum value, and the
        share of the most common interior value among all interior values
        (i.e. after dropping the minimum and the maximum).
        """
        freq = col.value_counts(normalize=True)
        lowest, highest = col.min(), col.max()
        interior = freq.drop(index=[lowest, highest])
        summary = {
            'mode': freq.max(),
            'min': freq[lowest],
            'max': freq[highest],
            'mode_freq_nominmax': interior.max() / interior.sum(),
        }
        return pd.Series(summary).round(2)

    table = data_gss.apply(lambda x: key_freq(x.dropna())).T
    print(tabulate(table, headers='keys', tablefmt='psql'))

    # --- Same fit with verbose output ---------------------------------------
    # Run only for what the model prints; the imputed result is not needed.
    GaussianCopula(verbose=1).fit_transform(X=gss_masked)

    # --- Imputation error versus training iterations ------------------------
    num_iter = 15
    m = GaussianCopula(verbose=1)

    def err(x):
        return get_smae(x, x_true=data_gss, x_obs=gss_masked).mean()

    r = m.fit_transform_evaluate(gss_masked, eval_func=err, num_iter=num_iter)
    errors = r['evaluation']
    fig, ax = plt.subplots()
    # Derive the x axis from the recorded errors instead of repeating the
    # iteration count as a second literal that has to be kept in sync.
    ax.plot(np.arange(1, len(errors) + 1), errors)
    ax.set_title('Imputation error versus run iterations')
    ax.set_xlabel("Run iterations")
    ax.set_ylabel("SMAE")
    fig.savefig(loc_plots, bbox_inches='tight', format='pdf')
    plt.close(fig)

    # --- Minibatch-offline training: runtime and error ----------------------
    t1 = time.time()
    model_minibatch = GaussianCopula(training_mode='minibatch-offline')
    Ximp_batch = model_minibatch.fit_transform(X=gss_masked)
    t2 = time.time()
    print(f'Runtime: {t2 - t1:.2f} seconds')
    smae_batch = get_smae(Ximp_batch, x_true=data_gss, x_obs=gss_masked)
    print(f'Imputation error: {smae_batch.mean():.3f}')

    # --- Default training: runtime baseline for the comparison above --------
    t1 = time.time()
    GaussianCopula().fit_transform(X=gss_masked)
    t2 = time.time()
    print(f'Runtime: {t2 - t1:.2f} seconds')

def run_movielens(loc_plots):
    """Compare low-rank and full Gaussian copula imputation on MovieLens data.

    Loads the data with ``load_movielens1m(num=400, min_obs=150)``, masks 10%
    of the entries with ``mask_MCAR``, fits ``LowRankGaussianCopula(rank=10)``
    and then a default ``GaussianCopula`` on the masked data, and prints the
    wall-clock runtime and the imputation MAE of each model.

    Args:
        loc_plots: Unused; this experiment produces no figures. It is accepted
            because ``main`` calls every experiment with the same argument.
    """
    ratings = load_movielens1m(num=400, min_obs=150)
    ratings_masked = mask_MCAR(X=ratings, mask_fraction=0.1)

    def mae_of(imputed):
        # Error of an imputed matrix against the unmasked ratings.
        return get_mae(x_imp=imputed, x_true=ratings, x_obs=ratings_masked)

    # Each timed region covers model construction plus fit_transform.
    tic = time.time()
    lrgc_imputed = LowRankGaussianCopula(rank=10).fit_transform(X=ratings_masked)
    print(f'LRGC runtime {time.time() - tic:.2f} seconds.')

    tic = time.time()
    gc_imputed = GaussianCopula().fit_transform(X=ratings_masked)
    print(f'GC runtime {time.time() - tic:.2f} seconds.')

    print(f'LRGC imputation MAE: {mae_of(lrgc_imputed):.3f}')
    print(f'GC imputation MAE: {mae_of(gc_imputed):.3f}')

def run_FRED(loc_plots):
    """Run the online-prediction experiment on the data from ``load_FRED``.

    Plots every series, blanks out the two target columns entirely, fits a
    ``GaussianCopula`` in ``'minibatch-online'`` mode, and prints, for each
    target, the mean squared error of the copula output next to that of
    repeating the previous row's value. Only rows from ``n_train`` onwards are
    scored.

    Args:
        loc_plots: Target handed to ``Figure.savefig(..., format='pdf')`` for
            the overview figure. ``main`` passes an open ``PdfPages`` object.
    """
    # Single source of truth for the values shared by the fit and the scoring.
    n_train = 25
    targets = ['CrudeOilPrice', 'StockVolatility']

    fred = load_FRED()
    fred.plot(
        subplots=True, layout=(2, 4), figsize=(16, 6),
        legend=False, title=fred.columns.to_list()
    )
    # DataFrame.plot leaves its figure as the current one; save it, then close
    # it so pyplot does not keep it alive for the rest of the run.
    fig = plt.gcf()
    fig.savefig(loc_plots, bbox_inches='tight', format='pdf')
    plt.close(fig)

    model = GaussianCopula(
        training_mode='minibatch-online',
        window_size=10,
        const_stepsize=0.1,
        batch_size=10,
        decay=0.01
    )
    # Replace both target columns with NaN; column order is unchanged because
    # ``assign`` overwrites existing columns in place.
    Xmasked = fred.assign(**{col: np.nan for col in targets})
    Ximp = model.fit_transform(Xmasked, X_true=fred, n_train=n_train)

    for col in targets:
        # Look the column up by label. Ximp is indexed positionally, so relying
        # on the loop counter would only be right if the targets happened to be
        # the first two columns of the frame, in this exact order.
        # Assumes Ximp keeps the column order of its input - TODO confirm.
        col_idx = fred.columns.get_loc(col)
        series = fred[col].to_numpy()
        _true = series[n_train:]
        # Baseline: predict each row with the value of the row before it.
        _err_yes = series[n_train - 1:-1] - _true
        _err_GC = Ximp[n_train:, col_idx] - _true
        print(f'MSE of {col}:')
        print(f'Gaussian Copula Pred: {np.power(_err_GC, 2).mean():.3f}')
        print(f'Yesterday Value Pred: {np.power(_err_yes, 2).mean():.3f}')

def run_whitewine(loc_plots):
    """Run the regression and interval-coverage experiments on white wine data.

    Every column except the last is used as a feature matrix, 30% of it is
    masked with ``mask_MCAR`` and imputed with a default ``GaussianCopula``.
    A linear regression for ``quality`` is trained on the first 4000 rows and
    scored by MSE on the remaining rows, using in turn the complete features,
    the single imputation, and the average prediction over 5 sampled
    imputations. Finally the share of masked true values falling strictly
    inside the model's confidence intervals is printed, first for the default
    interval type and then for ``type='quantile'``.

    Args:
        loc_plots: Unused; this experiment produces no figures. It is accepted
            because ``main`` calls every experiment with the same argument.
    """
    from sklearn.metrics import mean_squared_error as MSE
    from sklearn.linear_model import LinearRegression as LR

    n_train = 4000   # rows used for fitting; the rest are held out
    n_draws = 5      # number of sampled imputations

    wine = load_whitewine()
    print(tabulate(wine.head().T, headers = 'keys', tablefmt = 'psql'))

    features = wine.to_numpy()[:, :-1]
    features_masked = mask_MCAR(features, mask_fraction=0.3)
    model_wine = GaussianCopula()
    features_imputed = model_wine.fit_transform(X=features_masked)

    quality = wine['quality']
    ytrain, ytest = quality[:n_train], quality[n_train:]

    def holdout_prediction(design):
        # Fit on the training rows of `design`, predict its held-out rows.
        return LR().fit(design[:n_train], ytrain).predict(design[n_train:])

    # Complete data, then the single imputation.
    print(np.round(MSE(ytest, holdout_prediction(features)),4))
    print(np.round(MSE(ytest, holdout_prediction(features_imputed)), 4))

    # Multiple imputation: one regression per draw, predictions averaged.
    draws = model_wine.sample_imputation(features_masked, num=n_draws)
    per_draw = [holdout_prediction(draws[..., k]) for k in range(n_draws)]
    pooled = np.mean(per_draw, axis=0)
    print(np.round(MSE(ytest, pooled), 4))

    # Empirical coverage of the intervals at the masked positions.
    missing = np.isnan(features_masked)
    truth_at_missing = features[missing]
    for ci_kwargs in ({}, {'type': 'quantile'}):
        interval = model_wine.get_confidence_interval(**ci_kwargs)
        below_upper = interval['upper'][missing] > truth_at_missing
        above_lower = interval['lower'][missing] < truth_at_missing
        print(np.round((above_lower & below_upper).mean(), 3))

def main(output_path="Figures/plots_output.pdf"):
    """Run every experiment in sequence and collect their figures in one PDF.

    The experiments run in the order GSS, movielens, FRED, whitewine. Each one
    receives the open ``PdfPages`` object and the elapsed minutes are printed
    after each experiment and once more for the whole run.

    Args:
        output_path: Path of the PDF to write. Its parent directory is created
            if it does not exist yet. Defaults to the location the script has
            always written to.
    """
    # Local import keeps this block self-contained; the module does not
    # import ``os`` at the top level.
    import os

    runs = {
        'GSS': run_GSS,
        'movielens': run_movielens,
        'FRED': run_FRED,
        'whitewine': run_whitewine,
    }

    # Without this, a fresh checkout that has no output directory fails with
    # FileNotFoundError as soon as the PDF file is opened.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    a0 = time.time()
    with PdfPages(output_path) as pdf:
        for name, f_run in runs.items():
            a = time.time()
            f_run(pdf)
            b = time.time()
            print(f'Finish running {name} experiments after {(b-a)/60:.2f} mins')
    b0 = time.time()
    print(f'Finish running all experiments after {(b0-a0)/60:.2f} mins')

# Run the full experiment suite only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

