from __future__ import absolute_import
from __future__ import print_function


import os
import sys
import csv
import yaml
import random
import shutil
import pathlib
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from datetime import datetime
from sklearn.model_selection import train_test_split
# Seed Python's `random` module so the `random.sample` / `random.shuffle`
# calls made further down in this module are repeatable between runs.
random.seed(49297)
# Register tqdm's pandas integration; this is what makes `progress_apply`
# available on the groupby objects used later in this module.
tqdm.pandas()

from . import mimic4csv
from . import preprocessing
from . import subject
from . import util


#from mimic4csv import *



# When true, extract_subjects prints the number of unique ICU stays,
# admissions and subjects remaining after each filtering step.
VERBOSE       = True

def log(msg):
    """Print ``msg`` as a banner: indented by four spaces, framed above and
    below by a tilde ruler, and padded with blank lines on both sides."""
    padding = "\n\n"
    ruler = "~ ~ " * 20
    for line in (padding, ruler, " " * 4 + msg, ruler, padding):
        print(line)


def extract_subjects(mimic4_path, output_path, n_subjects=1001):
    """Filter ICU stays and break the raw tables up into per-subject files.

    The patients, admissions and ICU-stay tables are read through
    ``mimic4csv`` and passed through its filtering helpers in sequence; when
    ``VERBOSE`` is set, the number of unique stay / admission / subject IDs
    is printed after each stage. Cohort-level CSVs (``all_stays.csv``,
    ``all_diagnoses.csv``, ``phenotype_labels.csv``) are written to
    ``output_path`` and the ``diagnosis_counts.csv`` path is handed to
    ``mimic4csv.count_icd_codes``. Finally a random sample of subjects is
    drawn and stays, diagnoses and events are broken up per subject.

    Args:
        mimic4_path: Directory containing the MIMIC-IV CSV files.
        output_path: Directory receiving all outputs; created if missing.
        n_subjects: Upper bound on the number of subjects sampled. The
            sample is capped at the number of available subjects, so small
            cohorts no longer make ``random.sample`` raise ``ValueError``.
    """
    log("extracting subjects")

    os.makedirs(output_path, exist_ok=True)

    def report(title, frame):
        # Print how many unique IDs of each kind are left after a stage.
        if VERBOSE:
            print('{}:\n\tICUSTAY_IDs: {}\n\tHADM_IDs: {}\n\tSUBJECT_IDs: {}'.format(
                title,
                frame.ICUSTAY_ID.unique().shape[0],
                frame.HADM_ID.unique().shape[0],
                frame.SUBJECT_ID.unique().shape[0]))

    patients = mimic4csv.read_patients_table(mimic4_path)
    admits = mimic4csv.read_admissions_table(mimic4_path)
    stays = mimic4csv.read_icustays_table(mimic4_path)
    report('BEFORE REMOVALS', stays)

    stays = mimic4csv.remove_icustays_with_transfers(stays)
    report('REMOVE ICU TRANSFERS', stays)

    stays = mimic4csv.merge_on_subject_admission(stays, admits)
    stays = mimic4csv.merge_on_subject(stays, patients)
    stays = mimic4csv.filter_admissions_on_nb_icustays(stays)
    report('REMOVE MULTIPLE STAYS PER ADMIT', stays)

    stays = mimic4csv.add_age_to_icustays(stays)
    stays = mimic4csv.add_inunit_mortality_to_icustays(stays)
    stays = mimic4csv.add_inhospital_mortality_to_icustays(stays)
    stays = mimic4csv.filter_icustays_on_age(stays)
    report('REMOVE PATIENTS AGE < 18', stays)

    stays.to_csv(os.path.join(output_path, 'all_stays.csv'), index=False)
    diagnoses = mimic4csv.read_icd_diagnoses_table(mimic4_path)
    diagnoses = mimic4csv.filter_diagnoses_on_stays(diagnoses, stays)

    diagnoses.to_csv(os.path.join(output_path, 'all_diagnoses.csv'), index=False)
    mimic4csv.count_icd_codes(diagnoses, output_path=os.path.join(output_path, 'diagnosis_counts.csv'))

    # The phenotype definitions live next to this module.
    definitions_path = os.path.join(pathlib.Path(__file__).parent.resolve(), "hcup_ccs_2015_definitions.yaml")
    with open(definitions_path, 'r') as yaml_f:
        phenotypes = preprocessing.add_hcup_ccs_2015_groups(diagnoses, yaml.safe_load(yaml_f))
    preprocessing.make_phenotype_label_matrix(phenotypes, stays).to_csv(
        os.path.join(output_path, 'phenotype_labels.csv'),
        index=False, quoting=csv.QUOTE_NONNUMERIC)

    subjects = stays.SUBJECT_ID.unique()
    # Same sampling call as before whenever enough subjects exist; the cap
    # only matters for cohorts smaller than n_subjects.
    sample_size = min(n_subjects, len(subjects))
    subjects = subjects[random.sample(range(len(subjects)), sample_size)]

    mimic4csv.break_up_stays_by_subject(stays, output_path, subjects=subjects)
    mimic4csv.break_up_diagnoses_by_subject(phenotypes, output_path, subjects=subjects)

    for table in ['icu/chartevents', 'hosp/labevents', 'icu/outputevents']:
        mimic4csv.read_events_table_and_break_up_by_subject(mimic4_path, table, output_path, items_to_keep=None,
                                                subjects_to_keep=subjects)


def validate_events(path):
    """Clean every subject's ``events.csv`` against that subject's ``stays.csv``.

    For each all-digit subdirectory of ``path`` the stays table is checked
    (no missing or repeated ICUSTAY_ID / HADM_ID) and the events table is
    rewritten in place keeping only rows that
      * have a HADM_ID,
      * have a HADM_ID listed in ``stays.csv``,
      * have an ICUSTAY_ID (filled in from ``stays.csv`` when empty) that
        equals the one ``stays.csv`` lists for the same HADM_ID.
    Counters for every exclusion reason are printed at the end.

    Subjects whose ``events.csv`` is missing or empty are skipped silently;
    any other read failure is reported on stderr and the subject skipped.

    Args:
        path: Root directory holding one subdirectory per subject.

    Raises:
        AssertionError: If a stays table violates the checks above, or if an
            empty ICUSTAY_ID could not be recovered.
    """
    log("validating events")

    n_events = 0                   # total number of events
    empty_hadm = 0                 # HADM_ID is empty in events.csv. We exclude such events.
    no_hadm_in_stay = 0            # HADM_ID does not appear in stays.csv. We exclude such events.
    no_icustay = 0                 # ICUSTAY_ID is empty in events.csv. We try to fix such events.
    recovered = 0                  # empty ICUSTAY_IDs are recovered according to stays.csv files (given HADM_ID)
    could_not_recover = 0          # empty ICUSTAY_IDs that are not recovered. This should be zero.
    icustay_missing_in_stays = 0   # ICUSTAY_ID does not appear in stays.csv. We exclude such events.

    id_dtypes = {'HADM_ID': str, "ICUSTAY_ID": str}
    subject_dirs = [name for name in os.listdir(path) if str.isdigit(name)]

    for subject_dir in tqdm(subject_dirs, desc='Iterating over subjects'):
        events_path = os.path.join(path, subject_dir, 'events.csv')

        stays_df = pd.read_csv(os.path.join(path, subject_dir, 'stays.csv'), index_col=False,
                               dtype=id_dtypes)
        stays_df.columns = stays_df.columns.str.upper()

        # assert that there is no row with empty ICUSTAY_ID or HADM_ID
        assert(not stays_df['ICUSTAY_ID'].isnull().any())
        assert(not stays_df['HADM_ID'].isnull().any())

        # assert there are no repetitions of ICUSTAY_ID or HADM_ID
        # since admissions with multiple ICU stays were excluded
        assert(len(stays_df['ICUSTAY_ID'].unique()) == len(stays_df['ICUSTAY_ID']))
        assert(len(stays_df['HADM_ID'].unique()) == len(stays_df['HADM_ID']))

        try:
            events_df = pd.read_csv(events_path, index_col=False, dtype=id_dtypes)
        except (FileNotFoundError, pd.errors.EmptyDataError):
            # the subject does not have an events table
            continue
        except Exception as err:
            # keep the best-effort behaviour, but make the failure visible
            sys.stderr.write('Error reading events for subject: {} ({})\n'.format(subject_dir, err))
            continue

        events_df.columns = events_df.columns.str.upper()
        n_events += events_df.shape[0]

        # we drop all events for which HADM_ID is empty
        empty_hadm += events_df['HADM_ID'].isnull().sum()
        events_df = events_df.dropna(subset=['HADM_ID'])

        merged_df = events_df.merge(stays_df, left_on=['HADM_ID'], right_on=['HADM_ID'],
                                    how='left', suffixes=['', '_r'], indicator=True)

        # we drop all events for which HADM_ID is not listed in stays.csv
        # since there is no way to know the targets of that stay (for example mortality)
        no_hadm_in_stay += (merged_df['_merge'] == 'left_only').sum()
        # explicit copy: the column assignment below must not target a view
        merged_df = merged_df[merged_df['_merge'] == 'both'].copy()

        # if ICUSTAY_ID is empty in events.csv, we try to recover it from stays.csv
        # we exclude all events for which we could not recover ICUSTAY_ID
        cur_no_icustay = merged_df['ICUSTAY_ID'].isnull().sum()
        no_icustay += cur_no_icustay
        merged_df['ICUSTAY_ID'] = merged_df['ICUSTAY_ID'].fillna(merged_df['ICUSTAY_ID_r'])
        still_missing = merged_df['ICUSTAY_ID'].isnull().sum()
        recovered += cur_no_icustay - still_missing
        could_not_recover += still_missing
        merged_df = merged_df.dropna(subset=['ICUSTAY_ID'])

        # now we take a look at the case when ICUSTAY_ID is present in events.csv, but not in stays.csv
        # this means that ICUSTAY_ID in events.csv is not the same as that of stays.csv for the same HADM_ID
        # we drop all such events
        same_icustay = merged_df['ICUSTAY_ID'] == merged_df['ICUSTAY_ID_r']
        icustay_missing_in_stays += (~same_icustay).sum()
        merged_df = merged_df[same_icustay]

        to_write = merged_df[['SUBJECT_ID', 'HADM_ID', 'ICUSTAY_ID', 'CHARTTIME', 'ITEMID', 'VALUE', 'VALUEUOM']]
        to_write.to_csv(events_path, index=False)

    assert(could_not_recover == 0)
    print('n_events: {}'.format(n_events))
    print('empty_hadm: {}'.format(empty_hadm))
    print('no_hadm_in_stay: {}'.format(no_hadm_in_stay))
    print('no_icustay: {}'.format(no_icustay))
    print('recovered: {}'.format(recovered))
    print('could_not_recover: {}'.format(could_not_recover))
    print('icustay_missing_in_stays: {}'.format(icustay_missing_in_stays))


def extract_episodes_from_subjects(path):
    """Write per-stay episode files for every subject directory under ``path``.

    For each subdirectory whose name parses as an integer, the subject's
    stays, diagnoses and events are read through the ``subject`` helpers.
    When all three items 220739, 223901 and 223900 are present, their text
    values are mapped to scores via ``glascow_map`` and the sum is appended
    to the events as synthetic item 198 (NaN wherever a component is
    missing). Events are then mapped to variables, cleaned and converted to
    a time series, and for every stay with data two files are written into
    the subject directory: ``episode{i}.csv`` and
    ``episode{i}_timeseries.csv``.

    Failures while reading a subject, or while deriving item 198, are
    reported on stderr; the subject is skipped in the former case and
    processed without item 198 in the latter.

    Args:
        path: Root directory holding one subdirectory per subject.
    """
    log("extracting episodes from subjects")

    glascow_map = {
        220739: {
            "Spontaneously": 4,
            "To Speech": 3,
            "To Pain": 2,
            "None": 1
        },

        223901: {
            "Obeys Commands": 6,
            "Localizes Pain": 5,
            "Flex-withdraws": 4,
            "Abnormal Flexion": 3,
            "Abnormal extension": 2,
            "No response": 1
        },

        223900: {
            "Oriented": 5,
            "Confused": 4,
            "Inappropriate Words": 3,
            "Incomprehensible sounds": 2,
            "No Response-ETT": 1,
            "No Response": 1
        }
    }
    glascow_itemids = (220739, 223901, 223900)

    var_map = preprocessing.read_itemid_to_variable_map(os.path.join(pathlib.Path(__file__).parent.resolve(),'itemid_to_variable_map.csv'))
    variables = var_map.VARIABLE.unique()

    for subject_dir in tqdm(os.listdir(path), desc='Iterating over subjects'):
        dn = os.path.join(path, subject_dir)
        try:
            subject_id = int(subject_dir)
        except ValueError:
            # not a subject directory
            continue
        if not os.path.isdir(dn):
            continue

        try:
            # reading tables of this subject
            stays = subject.read_stays(dn)
            diagnoses = subject.read_diagnoses(dn)
            events = subject.read_events(dn)
        except Exception:
            sys.stderr.write('Error reading from disk for subject: {}\n'.format(subject_id))
            continue

        glascow = events.loc[events['ITEMID'].isin(glascow_itemids)]

        try:
            # Making sure we have all three columns to use, otherwise there is no glascow total to calculate
            if all((glascow['ITEMID'] == itemid).any() for itemid in glascow_itemids):
                glascow = glascow[['CHARTTIME', 'ITEMID', 'VALUE']]
                glascow = glascow.pivot(index='CHARTTIME', columns='ITEMID', values='VALUE')

                for itemid in glascow_itemids:
                    glascow[itemid] = glascow[itemid].map(glascow_map[itemid])

                # `x != np.nan` is always True, so missing components have to
                # be detected with notna(); rows lacking any component get NaN.
                all_present = (glascow[220739].notna()
                               & glascow[223901].notna()
                               & glascow[223900].notna())
                glascow[198] = np.where(all_present,
                                        glascow[220739] + glascow[223901] + glascow[223900],
                                        np.nan)

                glascow = glascow.drop(columns=[220739, 223901, 223900])
                # reuse the rows of item 220739 as a template for the new item
                glascow_included = events.loc[events['ITEMID'] == 220739]
                glascow_included = glascow_included.merge(glascow, left_on='CHARTTIME', right_on='CHARTTIME')
                glascow_included = glascow_included.drop(columns=['VALUE', 'ITEMID'])
                glascow_included['ITEMID'] = 198
                glascow_included = glascow_included.rename(columns={198: 'VALUE'})
                events = pd.concat([events, glascow_included], ignore_index=True)
        except Exception as err:
            # best effort: `events` is only reassigned on success, so the
            # subject is still processed, just without item 198
            sys.stderr.write('Could not derive glascow total (ITEMID 198) for subject: {} ({})\n'.format(subject_id, err))

        episodic_data = preprocessing.assemble_episodic_data(stays, diagnoses)
        # cleaning and converting to time series
        events = preprocessing.map_itemids_to_variables(events, var_map)
        events = preprocessing.clean_events(events)

        if events.shape[0] == 0:
            # no valid events for this subject
            continue
        timeseries = subject.convert_events_to_timeseries(events, variables=variables)

        # extracting separate episodes
        for i in range(stays.shape[0]):
            stay_id = stays.ICUSTAY_ID.iloc[i]
            intime = stays.INTIME.iloc[i]
            outtime = stays.OUTTIME.iloc[i]

            episode = subject.get_events_for_stay(timeseries, stay_id, intime, outtime)
            if episode.shape[0] == 0:
                # no data for this episode
                continue

            episode = subject.add_hours_elapsed_to_events(episode, intime).set_index('HOURS').sort_index(axis=0)
            if stay_id in episodic_data.index:
                episodic_data.loc[stay_id, 'Weight'] = subject.get_first_valid_from_timeseries(episode, 'Weight')
                episodic_data.loc[stay_id, 'Height'] = subject.get_first_valid_from_timeseries(episode, 'Height')
            episodic_data.loc[episodic_data.index == stay_id].to_csv(
                os.path.join(dn, 'episode{}.csv'.format(i+1)),
                index_label='Icustay')

            columns_sorted = sorted(episode.columns, key=(lambda x: "" if x == "Hours" else x))
            episode = episode[columns_sorted]

            episode.to_csv(os.path.join(dn, 'episode{}_timeseries.csv'.format(i+1)),
                           index_label='Hours')


def split_train_and_test(path):
    """Split the subject directories under ``path`` 80/20 into train and test.

    The all-digit entries of ``path`` are split with ``train_test_split``
    (``random_state=42``, ``test_size=0.2``). The assignment is recorded in
    ``testset.csv`` (headerless ``subject_id,is_test`` rows sorted by subject
    id) and every subject directory is then moved into ``path/train`` or
    ``path/test``.

    The directory names are sorted before splitting: ``os.listdir`` returns
    entries in arbitrary order, so without the sort the fixed random state
    would not make the split reproducible across machines or runs.

    Args:
        path: Root directory holding one subdirectory per subject.
    """
    log("splitting train and test")

    def move_to_partition(patients, partition):
        # Move every listed subject directory into path/<partition>.
        partition_dir = os.path.join(path, partition)
        os.makedirs(partition_dir, exist_ok=True)
        for patient in patients:
            shutil.move(os.path.join(path, patient), os.path.join(partition_dir, patient))

    folders = sorted(filter(str.isdigit, os.listdir(path)))

    train, test = train_test_split(folders, random_state=42, test_size=0.2)

    train_df = pd.DataFrame(train, columns=['subject_id'])
    train_df['is_test'] = 0
    test_df = pd.DataFrame(test, columns=['subject_id'])
    test_df['is_test'] = 1
    df_out = pd.concat([train_df, test_df])

    df_out = df_out.sort_values(by=['subject_id'])
    df_out['is_test'] = df_out.is_test.astype(int)
    df_out.to_csv(os.path.join(path, 'testset.csv'), columns=['subject_id', 'is_test'], index=False, header=False)

    # The split is already in memory; no need to parse testset.csv back.
    test_set = set(test)
    train_patients = [x for x in folders if x not in test_set]
    test_patients = [x for x in folders if x in test_set]

    print("putting subjects in train...")
    move_to_partition(train_patients, "train")
    print("putting subjects in test...")
    move_to_partition(test_patients, "test")


def create_time_dep_data(input_path, output_path):
    """Join each episode's time series with its (repeated) label row.

    For both partitions ("train", "test") every ``*timeseries*`` file of
    every all-digit patient folder under ``input_path/<partition>`` is
    paired with its label file (same name without ``_timeseries``). Rows
    with ``Hours`` outside ``(-eps, length_of_stay_in_hours + eps)`` are
    dropped, the single label row is repeated once per remaining row (its
    Height / Weight columns renamed to ``*_static``) and the combined frame
    is written to ``output_path/<partition>/<patient>_<timeseries file>``.
    A ``listfile.csv`` with ``stay,y_true,los`` rows is written per
    partition: shuffled for train, sorted for test.
    """
    log("creating time-dependent data")

    def process_partition(partition, eps=1e-6):
        source_dir = os.path.join(input_path, partition)
        output_dir = os.path.join(output_path, partition)
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)

        samples = []
        patient_ids = [name for name in os.listdir(source_dir) if name.isdigit()]
        for patient in tqdm(patient_ids, desc='Iterating over patients in {}'.format(partition)):
            patient_folder = os.path.join(source_dir, patient)
            ts_filenames = [name for name in os.listdir(patient_folder) if "timeseries" in name]

            for ts_filename in ts_filenames:
                ts_df = pd.read_csv(os.path.join(patient_folder, ts_filename))
                label_path = os.path.join(patient_folder, ts_filename.replace("_timeseries", ""))
                label_df = pd.read_csv(label_path)

                # nothing to attach when the label file has no rows
                if label_df.shape[0] == 0:
                    continue

                first_label = label_df.iloc[0]
                mortality = int(first_label["Mortality"])
                los = 24.0 * first_label['Length of Stay']  # in hours
                if pd.isnull(los):
                    print("\n\t(length of stay is missing)", patient, ts_filename)
                    continue

                # keep only the measurements taken during the ICU stay
                hours = ts_df['Hours']
                during_stay = (hours > -eps) & (hours < los + eps)
                ts_df = ts_df[during_stay].reset_index(drop=True)
                if ts_df.shape[0] == 0:
                    print("\n\t(no events in ICU) ", patient, ts_filename)
                    continue

                # repeat the static label row once per time-series row
                label_df = label_df.rename(columns={'Height':'Height_static', 'Weight':'Weight_static'})
                repeated_labels = pd.DataFrame(np.repeat(label_df.values, len(ts_df), axis=0),
                                               columns=label_df.columns)
                combined = pd.concat([ts_df, repeated_labels], axis=1)

                output_ts_filename = patient + "_" + ts_filename
                combined.to_csv(os.path.join(output_dir, output_ts_filename), index=None)

                samples.append((output_ts_filename, mortality, los))

        print("Number of created samples:", len(samples))
        if partition == "train":
            random.shuffle(samples)
        if partition == "test":
            samples = sorted(samples)

        with open(os.path.join(output_dir, "listfile.csv"), "w") as listfile:
            listfile.write('stay,y_true,los\n')
            listfile.writelines('{},{:d},{:f}\n'.format(stay, y_true, stay_los)
                                for (stay, y_true, stay_los) in samples)

    if not os.path.exists(output_path):
        os.makedirs(output_path)

    partition_jobs = [{"partition": "train"}, {"partition": "test"}]
    util.run_as_Ps ([process_partition], partition_jobs)


def fldr_aggr_t_strt_t_end(path):
    """Aggregate the per-episode files and convert them to (t_start, t_end) rows.

    Two stages run for both modes ("train", "test") through
    ``util.run_as_Ps``:

    1. ``aggr_fldrs`` concatenates every file in ``path/<mode>`` that has an
       ``Icustay`` column into ``path/<mode>.csv``, sorted by subject,
       Icustay and Hours.
    2. ``to_t_start_t_end`` averages rows sharing the same ``Hours`` value,
       turns each row into an interval ending at the next row's ``Hours``
       (the last row of every stay is dropped) and writes
       ``path/tstart_tend_<mode>.csv``.

    The aggregation used to run twice (once sequentially and once through
    ``util.run_as_Ps``) producing the same files; it now runs once.
    """
    modes = ["train", "test"]

    def aggr_fldrs(mode):
        addr = os.path.join(path, mode)

        def read_data(name):
            # Returns None for anything that is not a file with an Icustay column.
            file_path = os.path.join(addr, name)
            if not os.path.isfile(file_path):
                return None
            data = pd.read_csv(file_path)
            if 'Icustay' not in data:
                return None
            # output files are named "<patient>_<timeseries file>"
            data['subject'] = name.split('_')[0]
            data['Icustay'] = data['Icustay'].fillna(data['Icustay'].mean())
            return data

        frames = [read_data(name) for name in os.listdir(addr)]
        frames = [frame for frame in frames if frame is not None]
        aggregated = pd.concat(frames)

        aggregated.sort_values(by=['subject', 'Icustay', 'Hours'], inplace=True)
        aggregated.to_csv(os.path.join(path, f"{mode}.csv"), index=None)

    def to_t_start_t_end(mode):
        # get average of measurements if they happen at the same time
        def average_on_t(data):
            cols = [col_ for col_ in data.columns if col_!='Hours']
            # consecutive rows with an equal Hours value share a group id
            grps = data['Hours'].ne(data['Hours'].shift()).cumsum()
            gpby = data.groupby(grps)
            return gpby.agg(dict(Hours='min')).join(gpby[cols].mean(numeric_only=True)).reset_index(drop=True)

        def one_id_to_t_start_t_end(data_):
            data_ = data_.sort_values(by=['Hours'])
            data_ = average_on_t(data_)
            # each interval ends where the next measurement starts
            data_['t_end'] = data_['Hours'].shift(-1)
            # the last row has no successor and therefore no t_end
            data_ = data_.iloc[:-1,:]
            return data_

        data = pd.read_csv(os.path.join(path, f"{mode}.csv"))
        data = util.put_cols_first(data, ['subject', 'Icustay'])

        data = data.groupby(['subject', 'Icustay']).progress_apply(one_id_to_t_start_t_end).reset_index(drop=True)
        data.rename(columns={'Hours':'t_start'}, inplace=True)
        data = util.put_cols_first(data, ['subject', 'Icustay', 't_start', 't_end'])
        data.to_csv(os.path.join(path, f"tstart_tend_{mode}.csv"), index=None)

    util.run_as_Ps ([aggr_fldrs],       [{"mode":mode} for mode in modes])
    util.run_as_Ps ([to_t_start_t_end], [{"mode":mode} for mode in modes])


def add_events(mimic4_path, in_path):
    """Overlay procedure events on the interval data and write ``mimic_iv_{mode}.csv``.

    Reads ``tstart_tend_{mode}.csv`` from ``in_path``, the ICU stays table
    and the rows of ``icu/procedureevents.csv`` with itemid 225792 (held in
    ``iv_data``) and 225448 (held in ``trach_data``) from ``mimic4_path``,
    then, per mode, runs ``add_events_`` through ``util.run_as_Ps``. The
    result is written to ``in_path/mimic_iv_{mode}.csv``.

    NOTE(review): judging by the variable names and inline comments the two
    item ids presumably are invasive ventilation and tracheostomy -- confirm
    against the MIMIC-IV item dictionary.
    """

    # number of decimals kept for all relative times (hours)
    NUM_DEC_PLACES = 3

    def clip(data, upper):
        """Keep rows whose t_start is below ``upper`` and cap t_end at ``upper``."""
        data           = data[data['t_start']<upper]
        # NOTE(review): this clips a column selected from a filtered frame in
        # place; whether the cap actually reaches `data` depends on pandas'
        # view/copy semantics -- confirm under the pandas version in use.
        data['t_end'].clip(upper=upper, inplace=True)
        return data.reset_index(drop=True)

    def fix_subj_icustay_t_start_end(data):
        """Cast subject / Icustay to int and round t_start / t_end to NUM_DEC_PLACES."""
        data[["subject", "Icustay"]] = data[["subject", "Icustay"]].astype('int')
        data[['t_start', 't_end']]   = data[['t_start', 't_end']  ].round(NUM_DEC_PLACES)
        return data

    def read_X(mode, input_path):
        """Load ``tstart_tend_{mode}.csv`` without the Diagnosis*, Mortality and Length of Stay columns."""
        data = pd.read_csv(os.path.join(input_path, f"tstart_tend_{mode}.csv"))#f"boxhed_data_{mode}.csv"))
        data.drop(columns = [col for col in data.columns if (col.startswith("Diagnosis") or (col in ['Mortality', 'Length of Stay']))], inplace=True)
        data = util.put_cols_first(data, ['subject', 'Icustay', 't_start', 't_end'])
        data = fix_subj_icustay_t_start_end(data)

        return data.reset_index(drop=True)


    def read_event_data(mimic4_addr, itemid):
        """Load the procedure events with the given itemid, renamed to the subject/Icustay/t_start/t_end schema."""
        data = pd.read_csv(os.path.join(mimic4_addr, "icu/procedureevents.csv"))
        data = data[data['itemid']==itemid]
        data.rename(columns={"subject_id": "subject", "stay_id": "Icustay", "starttime": "t_start", "endtime": "t_end"}, inplace=True)
        # NOTE(review): starttime/endtime are not parsed as dates by the
        # read_csv call above, and get_rel_evnt_start_end later feeds them to
        # datetime.fromisoformat, so the rounding inside this helper presumably
        # leaves them unchanged here -- confirm.
        data = fix_subj_icustay_t_start_end(data)
        return data

    def read_stays(mimic4_addr):
        """Load subject, Icustay and intime from ``icu/icustays.csv``."""
        data = pd.read_csv(os.path.join(mimic4_addr, "icu/icustays.csv"))[['subject_id', 'stay_id', 'intime']]
        return data.rename(columns={'subject_id': 'subject', 'stay_id':'Icustay'})


    def get_event_start_end(X, stays, event_data, mode):
        """Return the events of the stays present in ``X`` with times in hours since ICU intime.

        Events whose relative t_start is negative are dropped; the result is
        sorted by subject, Icustay and t_start. ``mode`` is not used here.
        """
        unique_subj_icustay = X[['subject','Icustay']].drop_duplicates()
        event_data          = event_data[['subject','Icustay', 't_start', 't_end']]
        # inner merge: only events of stays that occur in X survive
        evnt_abs_start_end = pd.merge(unique_subj_icustay, event_data, on=['subject','Icustay'])

        def get_rel_evnt_start_end(row):
            #stays          = pd.read_csv(f"/home/data/datasets/mimic_iv_AP/{mode}/{row['subject']}/stays.csv")
            icu_start_time = datetime.fromisoformat(stays[(stays['subject']==row['subject']) & (stays['Icustay']==row['Icustay'])]['intime'].iat[0])
            # seconds -> hours, rounded like every other time in this function
            row['t_start'] = round((datetime.fromisoformat(row['t_start']) - icu_start_time).total_seconds()/60/60, NUM_DEC_PLACES)
            row['t_end']   = round((datetime.fromisoformat(row['t_end'])   - icu_start_time).total_seconds()/60/60, NUM_DEC_PLACES)
            return row

        evnt_rel_start_end    = evnt_abs_start_end.apply(get_rel_evnt_start_end, axis = 1)
        evnt_rel_start_end    = evnt_rel_start_end[evnt_rel_start_end['t_start']>=0].reset_index(drop=True)
        evnt_rel_start_end.sort_values(by=['subject', 'Icustay', 't_start'], inplace=True, ignore_index=True)

        return evnt_rel_start_end

    def aggr_subsequent_iv(iv_data, max_time):
        """Merge, per stay, consecutive events separated by a gap of at most ``max_time``."""
        #https://stackoverflow.com/questions/46732760/merge-rows-pandas-dataframe-based-on-condition
        def aggr_conseq(data):
            # the group id increases whenever the gap to the previous row's
            # t_end exceeds max_time; each group collapses to one row spanning
            # its minimum t_start to its maximum t_end
            return data.groupby(((data['t_start']  - data['t_end'].shift(1)) > max_time).cumsum()).agg({'subject':min, 'Icustay':min, 't_start':min, 't_end':max})
        
        return iv_data.groupby(['subject', 'Icustay']).progress_apply(aggr_conseq).reset_index(drop=True)
        '''
        out = []
        for _, icustay_iv_data in :
            out.append(icustay_iv_data.groupby(((icustay_iv_data['t_start']  - icustay_iv_data['t_end'].shift(1)) > max_time).cumsum()).agg({'subject':min, 'Icustay':min, 't_start':min, 't_end':max}))
        return pd.concat(out, ignore_index=True)
        '''


    def break_row_into_two(row, t):
        """Return two copies of ``row`` covering [t_start, t] and [t, t_end]."""
        row             = pd.DataFrame(row).transpose()
        additional_rows = pd.DataFrame(np.repeat(row.values, 2, axis=0), columns=row.columns)
        additional_rows.loc[0, "t_end"]   = t
        additional_rows.loc[1, "t_start"] = t           
        return additional_rows


    def break_row_into_three(row, t1, t2):
        """Return three copies of ``row`` covering [t_start, t1], [t1, t2] and [t2, t_end]."""
        row             = pd.DataFrame(row).transpose()
        additional_rows = pd.DataFrame(np.repeat(row.values, 3, axis=0), columns=row.columns)
        additional_rows.loc[0, "t_end"]     = t1
        additional_rows.loc[1, "t_start"]   = t1
        additional_rows.loc[1, "t_end"]     = t2  
        additional_rows.loc[2, "t_start"]   = t2    
        return additional_rows  

    def group_to_dict(data, cols):
        """Map each groupby key over ``cols`` to its sub-frame."""
        return {key: data for key, data in data.groupby(cols)}

    def add_events_(X, stays, iv_data, trach_data, mode, CUTOFF, IV_max_time_diff, path):
        """Apply the trach / IV events to ``X`` for one mode and write the result.

        Adds the columns ``#past_IVs``, ``t_from_last_IV_t_start``,
        ``t_from_last_IV_t_end`` and ``delta`` (plus ``Y`` in test mode) and
        writes ``mimic_iv_{mode}.csv`` into ``path``.
        """

        iv_rel_start_end    = get_event_start_end(X, stays, iv_data, mode)
        trach_rel_start_end = get_event_start_end(X, stays, trach_data, mode)

        # cut off
        [X, iv_rel_start_end, trach_rel_start_end] = [clip(data, CUTOFF) for data in [X, iv_rel_start_end, trach_rel_start_end]]
        X                   = X[X['t_start']<X['t_end']]

        iv_rel_start_end    = aggr_subsequent_iv(iv_rel_start_end, IV_max_time_diff)       

        # per-stay lookup tables keyed by (subject, Icustay)
        trachs = group_to_dict(trach_rel_start_end, ['subject', 'Icustay'])
        ivs    = group_to_dict(iv_rel_start_end,    ['subject', 'Icustay'])

        def add_IV_trach(data):  
            """Process the rows of a single (subject, Icustay) group."""
            subject, icustay = data['subject'].iloc[0], data['Icustay'].iloc[0]

            #censoring based on trach and invasive ventilation
            if (subject, icustay) in trachs:
                # everything from the first trach event's start onwards is cut
                data = clip(data, trachs[(subject, icustay)].head(1)['t_start'].item())

            if (subject, icustay) in ivs:
                iv = ivs[(subject, icustay)]
            else:
                # no IV events for this stay: nothing more to do
                return data

            data.reset_index(drop=True, inplace=True)

            # filtering based on IV
            # IV events are walked from the last one to the first one
            for i in range(len(iv)-1, -1, -1):
                data = data[data['t_start']>=0]
                iv_ = iv.iloc[i,:]
        
                # rows whose interval touches or overlaps this IV event
                impacted_idxs = data.index[(data['t_start']<=iv_['t_end']) & (data['t_end']>=iv_['t_start'])].tolist()
                
                if len(impacted_idxs)==0:
                    continue

                if mode == 'train':
                    # train: the time covered by the IV event is removed and the
                    # row ending at the IV start is flagged with delta = 1
                    # regardless of being impacted only at the very end or in the middle
                    orig_t_end                          = data.loc[impacted_idxs[0],  't_end']
                    orig_delta                          = data.loc[impacted_idxs[0],  'delta']
                    data.loc[impacted_idxs[0], 't_end'] = iv_['t_start']
                    data.loc[impacted_idxs[0], 'delta'] = 1
                    
                    if len(impacted_idxs)>1: # meaning there has been more than one row impacted
                        if  data.loc[impacted_idxs[-1], 't_end'] == iv_['t_end']: #if impacted till the very end
                            data.drop(index=impacted_idxs[-1], inplace = True)
                        else: # impacted till somewhere in the middle of the epoch
                            data.loc[impacted_idxs[-1], 't_start'] = iv_['t_end']
                        #dropping impacted rows in the middle
                        data.drop(index=impacted_idxs[1:-1], inplace = True)  
                    else: #len(impacted_idxs) == 1, meaning only one epoch has been impacted, so it should be cut into 2
                        # the second piece resumes at the IV end and keeps the
                        # original t_end and delta of the row
                        additional_row            = pd.DataFrame(data.loc[impacted_idxs[0],  :]).transpose()
                        additional_row['t_start'] = iv_['t_end']
                        additional_row['t_end']   = orig_t_end
                        additional_row['delta']   = orig_delta
                        data                      = pd.concat([data.loc[:impacted_idxs[0], :], additional_row, data.loc[impacted_idxs[0]+1:, :]]) #stitching all together

                    # drop zero- or negative-length intervals produced by the edits above
                    data = data[data['t_start']<data['t_end']]

                elif mode == 'test':
                    # test: rows are kept, but the pieces lying inside the IV
                    # event get Y = 0 and the piece ending at the IV start gets
                    # delta = 1; rows straddling a boundary are split there
                    to_concat = []

                    if len(impacted_idxs)>1:
                        if  data.loc[impacted_idxs[0],  "t_end"] == iv_['t_start']:
                            data.loc[impacted_idxs[0],  "delta"] = 1
                            to_concat.append(data.loc[:impacted_idxs[0], :])
                        else:
                            additional_rows = break_row_into_two(data.loc[impacted_idxs[0],  :], iv_['t_start'])
                            additional_rows.loc[0, "delta"] = 1
                            additional_rows.loc[1, "Y"]     = 0
                            to_concat.extend([data.loc[:impacted_idxs[0]-1, :], additional_rows])

                        data.loc[impacted_idxs[1:-1],  "Y"] = 0
                        to_concat.append(data.loc[impacted_idxs[1:-1],  :])

                        if  data.loc[impacted_idxs[-1], "t_start"] != iv_['t_end']: # ending in the middle
                            additional_rows = break_row_into_two(data.loc[impacted_idxs[-1],  :], iv_['t_end'])
                            additional_rows.loc[0, "Y"]     = 0
                            to_concat.extend([additional_rows, data.loc[impacted_idxs[-1]+1:, :]])
                            #data = pd.concat([data.loc[:impacted_idxs[0]-1, :]]+rows_to_add+[data.loc[impacted_idxs[-1]+1:, :]])
                        else:
                            to_concat.append(data.loc[impacted_idxs[-1]:, :])
                        
                    else:
                        # a single row contains the whole IV event: split it
                        # into before / during / after
                        additional_rows = break_row_into_three(data.loc[impacted_idxs[0],  :], iv_['t_start'], iv_['t_end'])
                        additional_rows.loc[0, "delta"] = 1
                        additional_rows.loc[1, "delta"] = 0
                        additional_rows.loc[1, "Y"]     = 0

                        to_concat = [data.loc[:impacted_idxs[0]-1, :], additional_rows, data.loc[impacted_idxs[0]+1:, :]]

                    
                    data = pd.concat(to_concat)
                    data.reset_index(drop=True, inplace=True)
                
                # positional labels are needed again for the next IV event
                data.reset_index(drop=True, inplace=True)
                    
            #computing #past_IVs, t_from_last_IV_t_start, t_from_last_IV_t_end. Could have been incorporated into above, but wanted to keep things simpler
            # NOTE(review): the three assignments below use chained indexing
            # (data[col][mask] = ...); whether they write through to `data`
            # depends on pandas' copy/view behaviour -- confirm under the
            # pandas version in use.
            for _, iv_ in iv.iterrows():
                after_start = data['t_start']>=iv_['t_start']
                data['#past_IVs'][after_start]  += 1
                data['t_from_last_IV_t_start'][after_start] = data['t_start'][after_start]-iv_['t_start']

                after_end   = data['t_start']>=iv_['t_end']
                data['t_from_last_IV_t_end'  ][after_end] = data['t_start'][after_end]-iv_['t_end']

                data = data[data['t_end']>data['t_start']]
            
            return data
        
        # initial values of the columns filled in by add_IV_trach
        X['#past_IVs'], X['t_from_last_IV_t_start'], X['t_from_last_IV_t_end'], X['delta'] = [0, np.nan, np.nan, 0]
        if mode=="test":
            X['Y'] = 1
        output = X.groupby(['subject', 'Icustay']).progress_apply(add_IV_trach).reset_index(drop=True)
        output = output[output['t_start']<output['t_end']]
        output.to_csv(os.path.join(path, f"./mimic_iv_{mode}.csv"), index=None)

    # np.inf means no cut-off is applied
    CUTOFF           = np.inf#5*24 #hours    
    IV_max_time_diff = 0.5 #half an hour
    iv_data          = read_event_data(mimic4_path, 225792)
    trach_data       = read_event_data(mimic4_path, 225448)
    stays            = read_stays(mimic4_path)

    # one add_events_ job per mode; the shared tables are passed to both
    util.run_as_Ps([add_events_], [
        {"X":               read_X(mode, in_path), 
        "stays": stays,
        "iv_data":          iv_data, 
        "trach_data":       trach_data, 
        "mode":             mode, 
        "CUTOFF":           CUTOFF, 
        "IV_max_time_diff": IV_max_time_diff, 
        "path":             in_path} 
            for mode in ["train", "test"]])


def clean_data(in_path, out_path):
    """Produce the final train / test CSVs from ``in_path/mimic_iv_{mode}.csv``.

    Per mode the data is loaded, the time-series Height / Weight columns are
    replaced by the static ones, the categorical columns are one-hot
    encoded, and rows are restricted to ``0 <= t_start < t_end``. Values are
    then forward-filled within each ID and the first 24 hours of every
    episode are truncated. For the test set, episodes whose ``Y`` is never
    positive are removed. Results go to ``out_path/mimic_iv_{mode}.csv``.

    Args:
        in_path: Directory containing ``mimic_iv_train.csv`` / ``mimic_iv_test.csv``.
        out_path: Destination directory; created up front if missing so a
            missing directory cannot make the final write fail after all the
            processing is done.
    """
    os.makedirs(out_path, exist_ok=True)

    def read_mimic_iv_data(mode):
        data = pd.read_csv(os.path.join(in_path, f"mimic_iv_{mode}.csv"))

        # keep the static Height / Weight instead of the time-series ones
        data.drop(columns=["Height", "Weight"], inplace=True)
        data.rename(columns={"Height_static":"Height", "Weight_static":"Weight"}, inplace=True)

        first_cols     = ["subject", "Icustay", "t_start", "t_end"]
        categoricals   = ['Capillary refill rate', 'Ethnicity', 'Gender']
        last_cols      = ["delta", "Y"] if "Y" in data else ["delta"]
        non_numericals = first_cols + categoricals + last_cols

        numericals   = [col for col in data.columns if col not in non_numericals]
        data         = pd.concat([
                        data[first_cols],
                        pd.get_dummies(data[categoricals], columns=categoricals, dummy_na=True, drop_first=True),
                        data[numericals],
                        data[last_cols]], axis=1)

        data = data[data['t_start']<data['t_end']]
        data = data[data['t_start']>=0]
        data.drop(columns=["subject"], inplace=True)
        data.rename(columns={"Icustay":"ID"}, inplace=True)

        return data.reset_index(drop=True)


    def ffill(data):
        # forward-fill missing values within each ID
        return data.groupby(['ID']).progress_apply(lambda x: x.ffill(axis=0)).reset_index(drop=True)


    def remove_shorter_than(data, min_dur):
        print (f"truncating the first {min_dur} hours...")
        if min_dur == 0:
            return data

        def truncate(ep_data):
            # copy so the assignment below does not target a slice of the group
            ep_data            = ep_data[ep_data['t_end']>min_dur].copy()
            ep_data['t_start'] = ep_data['t_start'].clip(lower=min_dur)
            return ep_data

        data['EP_ID'] = util.get_episode_IDs(*[data[col].values for col in ['ID', 'delta']])
        data          = data.groupby('EP_ID').progress_apply(truncate).reset_index(drop=True)
        return data.drop(columns=['EP_ID'])


    def rm_no_risk_eps(data):
        print ("removing no-risk episodes...")
        def at_risk(ep_data):
            return ep_data['Y'].max()>0

        data['EP_ID'] =  util.get_episode_IDs(*[data[col].values for col in ['ID', 'delta']])
        data = data.groupby(['EP_ID']).filter(
            at_risk).reset_index(drop=True)

        return data.drop(columns=['EP_ID'])


    def time_varying_mimic_iv_data():
        monitor_time              = 24
        train_data, test_data     = [read_mimic_iv_data(mode)                for mode in ["train",    "test"]]
        train_data, test_data     = [ffill(data)                             for data in [train_data, test_data]]
        train_data, test_data     = [remove_shorter_than(data, monitor_time) for data in [train_data, test_data]]
        test_data                 =  rm_no_risk_eps(test_data)

        return {"train": train_data, "test": test_data}

    for mode, data in time_varying_mimic_iv_data().items():
        data.to_csv(os.path.join(out_path, f"./mimic_iv_{mode}.csv"), index=None)
    

import time
def time_now():
    """Return the current time as given by ``time.time()``."""
    current = time.time()
    return current

class timer:
    """Stopwatch that starts counting when it is constructed."""

    def __init__(self):
        # reference point for every later get_dur() call
        self.t_start = time_now()

    def get_dur(self):
        """Return the time elapsed since construction, rounded to 3 decimals."""
        elapsed = time_now() - self.t_start
        return round(elapsed, 3)


def extract_mimic_iv(mimic4_path, tmp_path, output_path):
    """Run every stage of the extraction pipeline in order.

    All intermediate artefacts are produced inside ``tmp_path``; only the
    last stage (``clean_data``) writes to ``output_path``.
    """
    stages = (
        (extract_subjects,               (mimic4_path, tmp_path)),
        (validate_events,                (tmp_path,)),
        (extract_episodes_from_subjects, (tmp_path,)),
        (split_train_and_test,           (tmp_path,)),
        (create_time_dep_data,           (tmp_path, tmp_path)),
        (fldr_aggr_t_strt_t_end,         (tmp_path,)),
        (add_events,                     (mimic4_path, tmp_path)),
        (clean_data,                     (tmp_path, output_path)),
    )
    for stage, stage_args in stages:
        stage(*stage_args)


if __name__ == "__main__":
    # Command-line entry point: three positional directories, then the pipeline.
    cli = argparse.ArgumentParser(description='Extract per-subject data from MIMIC-IV CSV files.')
    positional_arguments = (
        ('mimic4_path', 'Directory containing MIMIC-IV CSV files.'),
        ('tmp_path',    'Directory where intermediate (temporary) files are stored.'),
        ('output_path', 'Directory where the created dataset is saved.'),
    )
    for argument_name, argument_help in positional_arguments:
        cli.add_argument(argument_name, type=str, help=argument_help)

    # parse_known_args: unrecognised extra arguments are tolerated and ignored
    known_args, _unknown_args = cli.parse_known_args()

    extract_mimic_iv(known_args.mimic4_path, known_args.tmp_path, known_args.output_path)