import os
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from boxhed.boxhed import boxhed
from boxhed.utils import timer, run_as_process
from packages.boxhed1 import BoXHED # importing BoXHED 1.0

def set_context_to_fork():
    """Force the multiprocessing start method to 'fork' for the whole process.

    Raises:
        OSError: if the current platform does not offer the 'fork' start method.
    """
    import multiprocessing
    import platform

    if 'fork' not in multiprocessing.get_all_start_methods():
        raise OSError(f"ERROR! The current platform {platform.system()} does not support 'fork' process start method which is used in implementing"
                    " the simulation code. Consider using a MacOS or a Linux machine to run this code.")

    # force=True replaces any start method that was already set earlier in this process
    multiprocessing.set_start_method('fork', force=True)

# This function reads a runtime file of `count:runtime` lines and returns two numpy arrays: the counts (rows or covariates) and the corresponding runtimes
def runtimes_as_dict(fname):
    """Parse a runtime log file into two parallel numpy arrays.

    Each non-blank line of *fname* is expected to look like ``<count>:<seconds>``,
    the format appended by the timing functions in this file. ``<count>`` is a
    row or covariate count.

    Despite its name, this returns a two-element list (not a dict) so callers
    can unpack it directly.

    :param fname: path of the runtime file to read.
    :returns: ``[counts, runtimes]``, where ``counts`` is an integer array and
        ``runtimes`` is a float array, both in file order. Both arrays are empty
        when the file holds no data lines.
    :raises ValueError: if a data line has no ``:`` separator or a non-numeric field.
    """
    counts, runtimes = [], []
    with open(fname) as f:
        for line in f:
            line = line.strip()
            if not line:
                # tolerate blank/trailing lines instead of failing the unpacking below
                continue
            count, runtime = line.split(':', 1)
            # counts may be written as floats (e.g. '2000000.0'), so go through float first
            counts.append(int(float(count)))
            runtimes.append(float(runtime))
    return [np.array(counts, dtype=int), np.array(runtimes, dtype=float)]


def section_5_download_train_data(train_data_addr):
    """Fetch the Section 5 training data with gdown unless the file already exists.

    :param train_data_addr: local path where the training data file is (or will be) stored.
    """
    if os.path.isfile(train_data_addr):
        return  # already on disk: nothing to download

    # imported lazily so gdown is only needed when a download actually happens
    import gdown
    gdown.download(id='11Mh9JrXBnmEmNX0C1JvaTUKDuksq9KX7', output=train_data_addr, quiet=False)


def plot(plot_configs, xlabel, ylabel, plot_attr_dict=None):
    """Draw one scatter + fitted-line series per config on a single large figure and show it.

    :param plot_configs: iterable of dicts. Each dict has the keys ``'nrows'`` (x values),
        ``'runtimes'`` (y values), ``'color'``, ``'marker'``, ``'label'``, ``'s'``
        (marker size), ``'lw'`` (line width) and ``'ls'`` (line style).
    :param xlabel: x-axis label.
    :param ylabel: y-axis label.
    :param plot_attr_dict: optional mapping from a ``matplotlib.pyplot`` function name to
        the single argument it is called with after plotting, e.g.
        ``{'xlim': [1.8, 6.2], 'yscale': 'log'}``. ``None`` (the default) means no extra calls.
    """
    # avoid a shared mutable default argument
    if plot_attr_dict is None:
        plot_attr_dict = {}

    font_size = 40
    plt.figure(figsize=(19,12), dpi=300)
    plt.subplots_adjust(left=0.14, bottom=0.14, right=0.98, top=0.98)
    plt.xticks(fontsize= font_size)
    plt.yticks(fontsize= font_size)
    plt.xlabel(xlabel, fontsize=font_size)
    plt.ylabel(ylabel, fontsize=font_size)

    for plot_config in plot_configs:
        # order=1 fits a straight line through the points; ci=None draws no confidence band
        sns.regplot(x=plot_config['nrows'], y=plot_config['runtimes'], ci=None, color=plot_config['color'], order = 1, marker = plot_config['marker'], label = plot_config['label'],
                scatter_kws={'s':plot_config['s'], 'linewidth':plot_config['lw']},
                line_kws = {'linewidth':plot_config['lw'], 'linestyle': plot_config['ls']})

    # e.g. ('xlim', [a, b]) -> plt.xlim([a, b]); ('yscale', 'log') -> plt.yscale('log')
    for attr, val in plot_attr_dict.items():
        getattr(plt, attr)(val)

    plt.legend(frameon=True, prop={'size': font_size})
    plt.tight_layout()
    plt.show()


# This function times BoXHED 2.0 on increasing numbers of rows and saves the runtimes in a text file
def section_5_scalability_boxhed_by_nrow(gpu_id, nthread, train_data_addr, runtimes_file):
    """Time BoXHED 2.0 preprocessing + fitting on growing row-count prefixes of the training data.

    Row counts of 2, 4, 6 and 8 million are measured. When ``gpu_id == -1`` (CPU),
    10 million is measured as well. Each measurement runs through ``run_as_process``
    and appends a ``<nrow>:<seconds>`` line to *runtimes_file*. An existing
    *runtimes_file* is deleted before the first run.

    :param gpu_id: device id passed to ``boxhed``; ``-1`` is reported as CPU.
    :param nthread: thread count for ``boxhed`` and its preprocessing.
    :param train_data_addr: path of the gzip-compressed training CSV.
    :param runtimes_file: text file that the runtimes are appended to.
    """
    set_context_to_fork()

    @run_as_process
    def section_5_boxhed_train_save_runtime(gpu_id, nthread, nrow, train_data_addr, runtimes_file):
        # only the first `nrow` rows are loaded; loading itself is not timed
        subset = pd.read_csv(train_data_addr, compression='gzip', nrows=nrow)
        model = boxhed(gpu_id=gpu_id, nthread=nthread, max_depth=1, n_estimators=250)

        clock = timer()  # the timed region covers preprocessing and fitting
        prepped = model.preprocess(data=subset,
                                   num_quantiles=256,
                                   weighted=False,
                                   nthread=nthread)
        model.fit(prepped['X'], prepped['delta'], prepped['w'])
        elapsed = clock.get_dur()

        with open(runtimes_file, 'a+') as out:
            out.write(f'{int(nrow)}:{elapsed}\n')

        del model

    # every run appends, so start from a clean file
    if os.path.exists(runtimes_file):
        os.remove(runtimes_file)

    on_cpu = gpu_id == -1
    # 10 million rows only on CPU: per the original note it cannot fit into GPU
    row_counts = [2e6, 4e6, 6e6, 8e6] + ([10e6] if on_cpu else [])
    device = "CPU" if on_cpu else "GPU"
    for nrow in tqdm(row_counts, desc=f'Measuring BoXHED {device} scalability by # rows.'):
        section_5_boxhed_train_save_runtime(gpu_id, nthread, nrow, train_data_addr, runtimes_file)


# This function times BoXHED 1.0 and saves the runtimes in a text file
def section_5_boxhed_one_runtime(train_data_addr, runtimes_file):
    """Time BoXHED 1.0 fitting on the first 2, 4 and 6 million rows of the training data.

    For each row count, the data is converted from the BoXHED 2.0 layout to the
    BoXHED 1.0 layout (not timed). A ``BoXHED.BoXHED`` model is then built (timed)
    inside ``run_as_process``, and a ``<nrow>:<seconds>`` line is appended to
    *runtimes_file*. An existing *runtimes_file* is removed before the first run.

    :param train_data_addr: path of the gzip-compressed training CSV.
    :param runtimes_file: text file that the runtimes are appended to.
    """
    set_context_to_fork()

    # This function transforms the data as expected by BoXHED2.0 and transforms it to how BoXHED1.0 expects it.
    def section_5_2_to_1_data_convert(data):
        """Convert BoXHED 2.0 long-format rows into BoXHED 1.0's ``(lotraj, delta)`` input.

        Note: adds an ``EP_ID`` column to *data* in place.

        :returns: ``(lotraj, delta)``. ``lotraj`` is a list with one array per
            episode (``EP_ID``), holding that episode's rows after dropping ``ID``,
            ``delta``, ``t_end`` and ``EP_ID``. ``delta`` is the ``delta`` value of
            each episode's last row. Both are in the same (sorted ``EP_ID``) order.
        """
        from MIMIC_IV_extractor.util import get_episode_IDs

        def chng_idxs(arr):
            """Return the positions of the last element of each run of equal consecutive values."""
            arr = np.array(arr, dtype='object')
            return np.flatnonzero(np.concatenate(((arr[1:] != arr[:-1]), [True])))

        def frmat_cnvrt_two_to_one(ep_data):
            """Append a copy of an episode's last row whose ``t_start`` is set to that row's ``t_end``.

            :raises ValueError: if the episode's intervals are not contiguous, i.e. some
                ``t_start`` differs from the previous row's ``t_end``.
            """
            if not np.array_equal(
                            ep_data['t_start'].iloc[1:].values,
                            ep_data['t_end'].iloc[:-1].values):
                raise ValueError("ERROR: This function does not handle non-contiguous timing."
                                f" Check t_start and t_end values for ID={ep_data['ID'].iloc[0]}")
            ep_data = pd.concat([ep_data, ep_data.iloc[[-1]]], axis=0, ignore_index=True)
            # Set the cell with a single .iloc call on the frame itself. The previous chained
            # form `ep_data['t_start'].iloc[-1] = ...` assigns into an intermediate Series.
            # Under pandas Copy-on-Write that write never reaches `ep_data`, and older
            # pandas raises chained-assignment warnings for it.
            ep_data.iloc[-1, ep_data.columns.get_loc('t_start')] = ep_data['t_end'].iloc[-1]
            return ep_data

        data['EP_ID'] = get_episode_IDs(*[data[col].values for col in ['ID', 'delta']])
        # groupby sorts by EP_ID, so after apply each episode's rows are contiguous
        data          = data.groupby('EP_ID').apply(frmat_cnvrt_two_to_one).reset_index(drop=True)
        # one delta per episode, read from the last row of each contiguous EP_ID run
        delta  = data['delta'].iloc[chng_idxs(data['EP_ID'])].values
        data.drop(columns=['ID', 'delta', 't_end'], inplace=True)
        # same sorted EP_ID order as `delta`, so lotraj[i] pairs with delta[i]
        lotraj = [ep_data.drop(columns='EP_ID').values for _, ep_data in data.groupby('EP_ID')]
        return lotraj, delta

    @run_as_process
    def section_5_boxhed_one_train_save_runtime(nrow, train_data_addr, runtimes_file):
        train_data = pd.read_csv(train_data_addr, compression='gzip', nrows = nrow) # subset the data based on the number of rows

        lotraj, delta = section_5_2_to_1_data_convert(train_data)

        # start runtime (loading and format conversion above are deliberately excluded)
        timer_    = timer()

        boxhed_ = BoXHED.BoXHED(delta, lotraj, maxsplits=1, numtrees=1, numtimepartitions=10, numvarpartitions=10)

        fit_dur  = timer_.get_dur()

        with open(runtimes_file, 'a+') as f:
            f.write(f'{int(nrow)}:{fit_dur}\n')

        del boxhed_

    # every run appends, so start from a clean file
    if os.path.exists(runtimes_file):
        os.remove(runtimes_file)

    for nrow in tqdm([2e6, 4e6, 6e6], desc='Measuring BoXHED 1.0 scalability by # rows.'):
        section_5_boxhed_one_train_save_runtime(nrow, train_data_addr, runtimes_file)


# This function plots the runtimes of BoXHED 2.0 (CPU and GPU) against BoXHED 1.0
def section_5_plot_boxhed_two_one (boxhed_cpu_runtimes_file, boxhed_gpu_runtimes_file, boxhed_one_runtimes_file):
    """Plot BoXHED 1.0 against BoXHED 2.0 (CPU and GPU) runtimes by row count, log-scaled y axis.

    The BoXHED 2.0 series are cut to as many leading entries as there are
    BoXHED 1.0 measurements, so all three curves share the same row counts.
    """
    cpu_series = runtimes_as_dict(boxhed_cpu_runtimes_file)
    gpu_series = runtimes_as_dict(boxhed_gpu_runtimes_file)
    one_nrows, one_runtimes = runtimes_as_dict(boxhed_one_runtimes_file)

    # keep only the leading entries BoXHED 1.0 was also measured on
    n_common = len(one_nrows)
    cpu_nrows, cpu_runtimes = (arr[:n_common] for arr in cpu_series)
    gpu_nrows, gpu_runtimes = (arr[:n_common] for arr in gpu_series)

    dashed = dict(lw=2, ls=(0, (5, 10)), s=600)
    plot_configs = [
        dict(label='BoXHED1.0', nrows=one_nrows/1e6, runtimes=one_runtimes,
             color='blue', marker='*', **dashed),
        dict(label='BoXHED2.0 CPU', nrows=cpu_nrows/1e6, runtimes=cpu_runtimes,
             color='k', marker='s', **dashed),
        dict(label='BoXHED2.0 GPU', nrows=gpu_nrows/1e6, runtimes=gpu_runtimes,
             color='purple', marker='o', lw=2, ls='solid', s=800),
    ]

    plot(plot_configs, '# rows (in millions)', 'time (sec) on log scale', {'xlim': [1.8, 6.2], 'yscale': 'log'})


# This function plots the runtimes of boxhed and blackboost
def section_5_plot_boxhed_blackboost_runtimes (boxhed_cpu_runtimes_file, boxhed_gpu_runtimes_file, blackboost_runtimes_file):
    """Plot Blackboost against BoXHED 2.0 (CPU and GPU) runtimes by row count (in millions)."""
    cpu_nrows, cpu_runtimes = runtimes_as_dict(boxhed_cpu_runtimes_file)
    gpu_nrows, gpu_runtimes = runtimes_as_dict(boxhed_gpu_runtimes_file)
    bb_nrows, bb_runtimes = runtimes_as_dict(blackboost_runtimes_file)

    dashed = {'lw': 2, 'ls': (0, (5, 10)), 's': 600}
    solid = {'lw': 2, 'ls': 'solid', 's': 800}

    plot_configs = [
        {'label': 'Blackboost', 'nrows': bb_nrows/1e6, 'runtimes': bb_runtimes,
         'color': 'blue', 'marker': '*', **dashed},
        {'label': 'BoXHED2.0 CPU', 'nrows': cpu_nrows/1e6, 'runtimes': cpu_runtimes,
         'color': 'k', 'marker': 's', **dashed},
        {'label': 'BoXHED2.0 GPU', 'nrows': gpu_nrows/1e6, 'runtimes': gpu_runtimes,
         'color': 'purple', 'marker': 'o', **solid},
    ]

    plot(plot_configs, '# rows (in millions)', 'time (sec)', {'xlim': [1.8, 10.2]})


# This function times BoXHED 2.0 on increasing numbers of covariates and saves the runtimes in a text file
def section_5_scalability_boxhed_by_ncov(gpu_id, nthread, train_data_addr, runtimes_file):
    """Time BoXHED 2.0 preprocessing + fitting on 4 million rows with a growing number of covariates.

    Covariate counts 0, 5, ..., 40 are measured. For a count ``k``, only the
    columns ``ID``, ``t_start``, ``t_end``, ``X_0`` .. ``X_{k-1}`` and ``delta``
    are loaded. Each measurement runs through ``run_as_process`` and appends a
    ``<ncovs>:<seconds>`` line to *runtimes_file*. An existing *runtimes_file*
    is deleted before the first run.

    :param gpu_id: device id passed to ``boxhed``; ``-1`` is reported as CPU.
    :param nthread: thread count for ``boxhed`` and its preprocessing.
    :param train_data_addr: path of the gzip-compressed training CSV.
    :param runtimes_file: text file that the runtimes are appended to.
    """
    set_context_to_fork()

    @run_as_process
    def boxhed_train_save_runtime(gpu_id, nthread, nrow, ncovs, train_data_addr, runtimes_file):
        wanted_cols = ['ID', 't_start', 't_end', *(f"X_{i}" for i in range(ncovs)), 'delta']
        subset = pd.read_csv(train_data_addr, compression='gzip',
                             nrows=nrow,
                             usecols=wanted_cols)
        model = boxhed(gpu_id=gpu_id, nthread=nthread, max_depth=1, n_estimators=250)

        clock = timer()  # loading is not timed; preprocessing and fitting are
        prepped = model.preprocess(data=subset,
                                   num_quantiles=256,
                                   weighted=False,
                                   nthread=nthread)
        model.fit(prepped['X'], prepped['delta'], prepped['w'])
        elapsed = clock.get_dur()

        with open(runtimes_file, 'a+') as out:
            out.write(f'{int(ncovs)}:{elapsed}\n')

        del model

    # every run appends, so start from a clean file
    if os.path.exists(runtimes_file):
        os.remove(runtimes_file)

    n_rows = int(4e6)
    cov_counts = [int(v) for v in np.linspace(0, 40, num=9)]
    device = "CPU" if gpu_id == -1 else "GPU"
    for ncovs in tqdm(cov_counts, desc=f'Measuring BoXHED {device} scalability by # covariates.'):
        boxhed_train_save_runtime(gpu_id, nthread, n_rows, ncovs, train_data_addr, runtimes_file)


def section_5_plot_boxhed_scalability_ncovs (boxhed_cpu_file, boxhed_gpu_file):
    """Plot BoXHED 2.0 CPU and GPU runtimes against the number of covariates."""
    series = [
        ('BoXHED2.0 CPU', boxhed_cpu_file,
         {'color': 'k', 'marker': 's', 'lw': 2, 'ls': (0, (5, 10)), 's': 600}),
        ('BoXHED2.0 GPU', boxhed_gpu_file,
         {'color': 'purple', 'marker': 'o', 'lw': 2, 'ls': 'solid', 's': 800}),
    ]

    plot_configs = []
    for label, fname, style in series:
        ncovs, runtimes = runtimes_as_dict(fname)
        # `plot` reads x values from the 'nrows' key, so the covariate counts go there
        plot_configs.append({'label': label, 'nrows': ncovs, 'runtimes': runtimes, **style})

    plot(plot_configs, '# covs', 'time (sec)', {'xlim': [-2, 42]})



def section_6_train_boxhed_on_MIMIC_IV_iV(training_data, *, nthread=20, is_cat=None):
    """Preprocess *training_data* and fit a default-configured BoXHED 2.0 model on it.

    :param training_data: training data passed to ``boxhed.preprocess``.
    :param nthread: thread count for both the model and preprocessing (previously fixed at 20).
    :param is_cat: column indices that ``preprocess`` treats as categorical. Defaults to
        ``list(range(3, 12))`` (indices 3 through 11), the value previously hard-coded.
    :returns: the fitted ``boxhed`` estimator.
    """
    if is_cat is None:
        is_cat = list(range(3, 11+1))

    boxhed_ = boxhed(nthread = nthread)
    X_post  = boxhed_.preprocess(data = training_data, is_cat = is_cat,
                                 num_quantiles = 256, weighted = False, nthread = nthread)

    boxhed_.fit(X_post['X'], X_post['delta'], X_post['w'])

    return boxhed_


def dump_pickle(obj, addr):
    """Serialize *obj* with ``pickle`` (default protocol) to the file at *addr*, overwriting it."""
    with open(addr, 'wb') as out_file:
        pickle.dump(obj, out_file)


def load_pickle(addr):
    """Return the object unpickled from the file at *addr*.

    Warning: unpickling can execute arbitrary code, so only load files you trust
    (for example ones written by ``dump_pickle``).
    """
    with open(addr, 'rb') as in_file:
        return pickle.load(in_file)


def section_6_plot_hazard_over_time(times, hazards, fname=None):
    """Plot an estimated hazard curve over time using arrow-tipped axes.

    :param times: array of time points (x axis), labelled in hours; must support ``.min()``/``.max()``.
    :param hazards: array of hazard estimates aligned with *times* (y axis); must support ``.max()``.
    :param fname: optional output path. When given, the figure is also saved there
        (before it is shown). This replaces the former commented-out ``savefig``.
    """
    font_size = 40
    ms   = 25  # arrowhead marker size
    lw   = 8   # spine line width

    fig, ax = plt.subplots(figsize=(20, 12), dpi=300)

    # The y spine sits at t=24 and the x spine at hazard 0; the top/right spines are hidden.
    ax.spines['left'].set_position(('data', 24))
    ax.spines['bottom'].set_position(('data', 0))
    ax.spines['left'].set(lw=lw)
    ax.spines['bottom'].set(lw=lw)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    # Arrowheads: '>' at the right end of the x axis (x in axes units, y=0 in data units)
    # and '^' at the top of the y axis (x=24 in data units, y in axes units).
    ax.plot(1, 0, '>k', ms=ms, transform=ax.get_yaxis_transform(), clip_on=False)
    ax.plot(24, 1, '^k', ms=ms, transform=ax.get_xaxis_transform(), clip_on=False)

    # Axis labels drawn as free text near the arrow tips (x in axes units, y in data units).
    ax.text(1.03, -0.03*hazards.max(), s=r'$t$ $(hour)$', fontsize=font_size+20, transform=ax.get_yaxis_transform(), clip_on=False)
    ax.text(-0.11, 1.05*hazards.max(), s=r'$\hat\lambda(t, X(t))$', fontsize=font_size+20, transform=ax.get_yaxis_transform(), clip_on=False)

    plt.xlim([times.min(), times.max()+2])
    plt.ylim([0, hazards.max()])

    ax.plot(times, hazards, lw=4, color='k')

    # ticks at each multiple of 24 up to times.max()
    xticks_ = [24*i for i in range(1, int(times.max()/24)+1)]
    plt.xticks(xticks_, xticks_, fontsize= font_size)
    plt.yticks(fontsize= font_size)
    ax.tick_params(axis='both', which='major', pad=15)
    plt.tight_layout()
    if fname is not None:
        # save before show() so the file matches the displayed figure even if show() closes it
        fig.savefig(fname)
    plt.show()


# This function plots the variable importances in BoXHED
def section_6_plot_boxhed_var_imps(varimps, top_k=10):
    """Bar-plot the *top_k* most important variables, scaled so the largest importance is 1.

    :param varimps: mapping of variable name to importance score.
    :param top_k: number of highest-importance variables to show.
    :raises ValueError: if *varimps* is empty.
    """
    def plot_var_imps(var_names, imps):
        var_names = var_names[:top_k]
        imps = imps[:top_k]

        font_size = 40
        _, axis = plt.subplots(figsize=(20,12), dpi=300)
        plt.xticks(fontsize= font_size-6)
        plt.yticks(fontsize= font_size)
        plt.bar(var_names, imps, color='steelblue')
        plt.xticks(rotation = -90, weight='bold')
        labels = axis.set_xticklabels(var_names)
        # move each tick label up by 0.967 from its default y position
        # (appears intended to draw the labels inside the plot area; verify visually)
        for label in labels:
            label.set_y(label.get_position()[1]+0.967)
        plt.tight_layout()
        plt.show()

    if not varimps:
        raise ValueError("varimps is empty: there are no variable importances to plot")

    # split the {variable: importance} mapping into two parallel arrays
    names, imps = (np.array(arr) for arr in zip(*varimps.items()))
    order = imps.argsort()[::-1]  # descending importance
    names = names[order]
    imps = imps[order]
    max_imp = imps.max()
    if max_imp != 0:
        # relative importance; skipped for all-zero input to avoid 0/0 -> NaN bars
        imps = imps/max_imp

    # shorter display labels; keys keep the spelling used by the source variable names
    display_names = {
        'Fraction inspired oxygen': 'Fraction inspired O2',
        'Glascow coma scale total': 'Glascow coma scale',
        'Oxygen saturation':        'O2 saturation',
        'Heart Rate':               'Heart rate',
    }

    names = [display_names.get(name, name) for name in names]

    plot_var_imps(names, imps)
