from __future__ import absolute_import
from __future__ import print_function

import csv
import numpy as np
import os
import pandas as pd
from tqdm import tqdm

from . import util

# TODO: change to mimic4_path

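# TODO: come back to this
# MIMIC-IV patients.csv has no DOB column (DOB is approximated below from
# anchor_year - anchor_age), and DOD is not returned from here, so it will have to be merged in.
# MISSING: dob, dod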
def read_patients_table(mimic4_path):
    """Load the patients table and derive an approximate date of birth.

    Reads ``core/patients.csv`` under ``mimic4_path`` through
    ``util.dataframe_from_csv``. The birth year is computed as
    ``anchor_year - anchor_age`` and parsed with the ``%Y`` format, so every
    DOB is a timestamp built from a year only.

    Args:
        mimic4_path: Root directory of the MIMIC-IV CSV export.

    Returns:
        DataFrame with columns ``SUBJECT_ID``, ``GENDER`` and ``DOB``.
    """
    pats = util.dataframe_from_csv(os.path.join(mimic4_path, 'core/patients.csv'))
    # Work on an explicit copy: adding DOB to a column-subset slice would
    # otherwise trigger pandas' SettingWithCopyWarning.
    pats = pats[['subject_id', 'gender', 'anchor_age', 'anchor_year', 'dod']].copy()

    pats['DOB'] = pats['anchor_year'] - pats['anchor_age']
    pats['DOB'] = pd.to_datetime(pats['DOB'], format="%Y")

    # Rename by name instead of overwriting `.columns` positionally.
    return pats[['subject_id', 'gender', 'DOB']].rename(
        columns={'subject_id': 'SUBJECT_ID', 'gender': 'GENDER'})

# MISSING: diagnoses
# TODO: look into ways deaths are handled
def read_admissions_table(mimic4_path):
    """Load the admissions table with parsed timestamps and a DOD column.

    Reads ``core/admissions.csv`` under ``mimic4_path`` through
    ``util.dataframe_from_csv``. ``DOD`` is a copy of ``deathtime``.

    Args:
        mimic4_path: Root directory of the MIMIC-IV CSV export.

    Returns:
        DataFrame with columns ``SUBJECT_ID``, ``HADM_ID``, ``ADMITTIME``,
        ``DISCHTIME``, ``DEATHTIME``, ``ETHNICITY`` and ``DOD``.
    """
    admits = util.dataframe_from_csv(os.path.join(mimic4_path, 'core/admissions.csv'))
    # Explicit copy so the assignments below do not write onto a slice of the
    # raw frame (SettingWithCopyWarning).
    admits = admits[['subject_id', 'hadm_id', 'admittime', 'dischtime', 'deathtime', 'ethnicity']].copy()

    for time_col in ('admittime', 'dischtime', 'deathtime'):
        admits[time_col] = pd.to_datetime(admits[time_col])

    admits['DOD'] = admits['deathtime']

    # Rename by name rather than by position so a label can never land on the
    # wrong column.
    return admits.rename(columns={
        'subject_id': 'SUBJECT_ID',
        'hadm_id': 'HADM_ID',
        'admittime': 'ADMITTIME',
        'dischtime': 'DISCHTIME',
        'deathtime': 'DEATHTIME',
        'ethnicity': 'ETHNICITY',
    })


# TODO: might be best to get rid of DBSOURCE and fix later logic
# TODO: order might matter, might need to come back to this
def read_icustays_table(mimic4_path):
    """Load the ICU stays table with upper-case column names.

    Reads ``icu/icustays.csv`` under ``mimic4_path`` through
    ``util.dataframe_from_csv``, parses ``intime``/``outtime`` as timestamps
    and adds a constant ``DBSOURCE`` column set to ``"metavision"``.

    Columns are selected and renamed by name, so the result does not depend on
    the column order of the CSV file.

    Args:
        mimic4_path: Root directory of the MIMIC-IV CSV export.

    Returns:
        DataFrame with columns ``SUBJECT_ID``, ``HADM_ID``, ``ICUSTAY_ID``,
        ``FIRST_CAREUNIT``, ``LAST_CAREUNIT``, ``INTIME``, ``OUTTIME``, ``LOS``
        and ``DBSOURCE`` (in that order).

    Raises:
        KeyError: If one of the expected source columns is absent.
    """
    column_map = [
        ('subject_id', 'SUBJECT_ID'),
        ('hadm_id', 'HADM_ID'),
        ('stay_id', 'ICUSTAY_ID'),
        ('first_careunit', 'FIRST_CAREUNIT'),
        ('last_careunit', 'LAST_CAREUNIT'),
        ('intime', 'INTIME'),
        ('outtime', 'OUTTIME'),
        ('los', 'LOS'),
    ]

    stays = util.dataframe_from_csv(os.path.join(mimic4_path, 'icu/icustays.csv'))
    # Pick the expected columns explicitly; this fixes the output order and
    # fails loudly if the schema is not the expected one.
    stays = stays[[source for source, _ in column_map]].copy()
    stays['intime'] = pd.to_datetime(stays['intime'])
    stays['outtime'] = pd.to_datetime(stays['outtime'])

    stays = stays.rename(columns=dict(column_map))
    stays['DBSOURCE'] = "metavision"
    return stays


# TODO: work in transfers table
# No entry for wardID, could also get rid of entries with multiple stays in one hospital admission
def remove_icustays_with_transfers(stays):
    """Keep only stays whose first and last care unit are the same.

    Returns the matching rows restricted to the identifier, care-unit, source
    and timing columns; the original row index is preserved.
    """
    kept_columns = ['SUBJECT_ID', 'HADM_ID', 'ICUSTAY_ID', 'LAST_CAREUNIT',
                    'DBSOURCE', 'INTIME', 'OUTTIME', 'LOS']
    same_unit = stays['FIRST_CAREUNIT'] == stays['LAST_CAREUNIT']
    return stays.loc[same_unit, kept_columns]


def merge_on_subject(table1, table2):
    """Inner-join two tables on their shared ``SUBJECT_ID`` column."""
    join_keys = ['SUBJECT_ID']
    return pd.merge(table1, table2, how='inner', on=join_keys)


def merge_on_subject_admission(table1, table2):
    """Inner-join two tables on the ``(SUBJECT_ID, HADM_ID)`` pair."""
    join_keys = ['SUBJECT_ID', 'HADM_ID']
    return pd.merge(table1, table2, how='inner', on=join_keys)


# Q: Is some of the logic before this redundant because of this function?
def filter_admissions_on_nb_icustays(stays, min_nb_stays=1, max_nb_stays=1):
    """Keep stays from admissions with an acceptable number of ICU stays.

    The number of non-null ``ICUSTAY_ID`` values is counted per ``HADM_ID``;
    admissions whose count lies in ``[min_nb_stays, max_nb_stays]`` (both
    bounds inclusive) are kept. The result comes out of an inner merge, so it
    carries a fresh index.
    """
    stays_per_admission = stays.groupby('HADM_ID')['ICUSTAY_ID'].count()
    in_range = stays_per_admission.between(min_nb_stays, max_nb_stays)
    eligible = stays_per_admission[in_range].reset_index()[['HADM_ID']]
    return stays.merge(eligible, how='inner', on='HADM_ID')


def add_age_to_icustays(stays):
    """Add an ``AGE`` column (in years) computed at ICU admission.

    Age is the time between ``DOB`` and ``INTIME`` expressed in seconds and
    then divided down using a 365-day year. The frame is modified in place and
    also returned.

    Args:
        stays: DataFrame with datetime columns ``INTIME`` and ``DOB``.

    Returns:
        The same DataFrame object with a float ``AGE`` column added.
    """
    # Vectorised timedelta division instead of a per-row Python lambda: same
    # arithmetic, no Python-level loop, and a float column even when the
    # frame is empty.
    elapsed_seconds = (stays.INTIME - stays.DOB) / np.timedelta64(1, 's')
    stays['AGE'] = elapsed_seconds / 60. / 60 / 24 / 365
    return stays


# TODO: check to see how this compares to the 'expired' tag
def add_inhospital_mortality_to_icustays(stays):
    """Flag stays whose patient died during the hospital admission.

    A stay is flagged when either ``DOD`` or ``DEATHTIME`` is present and lies
    within ``[ADMITTIME, DISCHTIME]`` (inclusive). The 0/1 flag is written to
    both ``MORTALITY`` and ``MORTALITY_INHOSPITAL``; the frame is modified in
    place and returned.
    """
    def died_during_admission(death_column):
        death = stays[death_column]
        after_admit = death >= stays['ADMITTIME']
        before_discharge = death <= stays['DISCHTIME']
        return death.notnull() & after_admit & before_discharge

    died = died_during_admission('DOD') | died_during_admission('DEATHTIME')
    stays['MORTALITY'] = died.astype(int)
    stays['MORTALITY_INHOSPITAL'] = stays['MORTALITY']
    return stays


def add_inunit_mortality_to_icustays(stays):
    """Flag stays whose patient died while in the ICU.

    A stay is flagged when either ``DOD`` or ``DEATHTIME`` is present and lies
    within ``[INTIME, OUTTIME]`` (inclusive). The 0/1 flag is stored in
    ``MORTALITY_INUNIT``; the frame is modified in place and returned.
    """
    def died_in_unit(death_column):
        death = stays[death_column]
        after_entry = death >= stays['INTIME']
        before_exit = death <= stays['OUTTIME']
        return death.notnull() & after_entry & before_exit

    died = died_in_unit('DOD') | died_in_unit('DEATHTIME')
    stays['MORTALITY_INUNIT'] = died.astype(int)
    return stays


def filter_icustays_on_age(stays, min_age=18, max_age=np.inf):
    """Return the stays whose ``AGE`` lies in ``[min_age, max_age]`` (inclusive)."""
    age_in_range = stays['AGE'].between(min_age, max_age)
    return stays[age_in_range]


# TODO: Filter on icd_10?
# TODO: delete stuff new to mimic4?
def read_icd_diagnoses_table(mimic4_path):
    """Load ICD-9 diagnoses joined with their long titles.

    Reads ``hosp/d_icd_diagnoses.csv`` and ``hosp/diagnoses_icd.csv`` under
    ``mimic4_path`` through ``util.dataframe_from_csv`` and inner-joins them
    on ``(icd_code, icd_version)``. Only rows with ``icd_version == 9`` are
    kept, and surrounding whitespace is stripped from the codes.

    Args:
        mimic4_path: Root directory of the MIMIC-IV CSV export.

    Returns:
        DataFrame with columns ``SUBJECT_ID``, ``HADM_ID``, ``SEQ_NUM``,
        ``ICD_CODE``, ``ICD_VERSION`` and ``LONG_TITLE``.

    Raises:
        KeyError: If one of the expected source columns is absent.
    """
    column_map = [
        ('subject_id', 'SUBJECT_ID'),
        ('hadm_id', 'HADM_ID'),
        ('seq_num', 'SEQ_NUM'),
        ('icd_code', 'ICD_CODE'),
        ('icd_version', 'ICD_VERSION'),
        ('long_title', 'LONG_TITLE'),
    ]

    # Going to keep icd_codes and icd_version
    codes = util.dataframe_from_csv(os.path.join(mimic4_path, 'hosp/d_icd_diagnoses.csv'))
    diagnoses = util.dataframe_from_csv(os.path.join(mimic4_path, 'hosp/diagnoses_icd.csv'))
    diagnoses = diagnoses.merge(codes, how='inner', on=['icd_code', 'icd_version'])

    id_columns = ['subject_id', 'hadm_id', 'seq_num']
    diagnoses[id_columns] = diagnoses[id_columns].astype(int)

    # Select and rename by name so the labels cannot drift if the merged
    # column order ever differs from the expected one.
    diagnoses = diagnoses[[source for source, _ in column_map]].rename(columns=dict(column_map))

    # Only keeping diagnoses with ICD-9 codes
    # TODO: Look into how to incorporate ICD-10
    # Copy the filtered rows so the assignment below does not write onto a
    # slice (SettingWithCopyWarning).
    diagnoses = diagnoses[diagnoses['ICD_VERSION'] == 9].copy()

    # removing whitespace in ICD_CODES
    diagnoses['ICD_CODE'] = diagnoses['ICD_CODE'].str.strip()

    return diagnoses


def filter_diagnoses_on_stays(diagnoses, stays):
    """Attach ``ICUSTAY_ID`` to diagnoses and drop those without a matching stay.

    Diagnoses are inner-joined with the distinct
    ``(SUBJECT_ID, HADM_ID, ICUSTAY_ID)`` triples of ``stays`` on
    ``(SUBJECT_ID, HADM_ID)``.
    """
    join_keys = ['SUBJECT_ID', 'HADM_ID']
    stay_ids = stays[join_keys + ['ICUSTAY_ID']].drop_duplicates()
    return diagnoses.merge(stay_ids, how='inner', on=join_keys)


def count_icd_codes(diagnoses, output_path=None):
    """Count how many ICU-stay diagnoses reference each ICD code.

    Builds a table of distinct ``(ICD_CODE, LONG_TITLE)`` pairs with a
    ``COUNT`` of non-null ``ICUSTAY_ID`` values per code and drops codes whose
    count is zero. When ``output_path`` is truthy the table is written to that
    CSV file before it is sorted. The returned frame is sorted by ``COUNT`` in
    descending order, with ``ICD_CODE`` as a regular column.
    """
    occurrences = diagnoses.groupby('ICD_CODE')['ICUSTAY_ID'].count()

    summary = diagnoses[['ICD_CODE', 'LONG_TITLE']].drop_duplicates()
    summary = summary.set_index('ICD_CODE')
    summary['COUNT'] = occurrences
    summary['COUNT'] = summary['COUNT'].fillna(0).astype(int)
    summary = summary[summary['COUNT'] > 0]

    if output_path:
        summary.to_csv(output_path, index_label='ICD_CODE')

    ranked = summary.sort_values('COUNT', ascending=False)
    return ranked.reset_index()


def break_up_stays_by_subject(stays, output_path, subjects=None):
    """Write one ``stays.csv`` per subject under ``output_path/<subject_id>/``.

    Args:
        stays: DataFrame with at least ``SUBJECT_ID`` and ``INTIME`` columns.
        output_path: Directory in which the per-subject folders are created.
        subjects: Optional sized iterable of subject ids to export; defaults
            to every distinct ``SUBJECT_ID`` in ``stays``.

    Raises:
        OSError: If a subject directory cannot be created.
    """
    subjects = stays.SUBJECT_ID.unique() if subjects is None else subjects
    # len() also accepts plain lists and tuples, unlike ``.shape[0]``.
    nb_subjects = len(subjects)
    for subject_id in tqdm(subjects, total=nb_subjects, desc='Breaking up stays by subjects'):
        dn = os.path.join(output_path, str(subject_id))
        try:
            os.makedirs(dn)
        except OSError:
            # An already-existing directory is fine; anything else (for
            # example a permission problem) is a real error.
            if not os.path.isdir(dn):
                raise

        subject_stays = stays[stays.SUBJECT_ID == subject_id].sort_values(by='INTIME')
        subject_stays.to_csv(os.path.join(dn, 'stays.csv'), index=False)

def break_up_diagnoses_by_subject(diagnoses, output_path, subjects=None):
    """Write one ``diagnoses.csv`` per subject under ``output_path/<subject_id>/``.

    Each subject's rows are sorted by ``ICUSTAY_ID`` and then ``SEQ_NUM``.

    Args:
        diagnoses: DataFrame with at least ``SUBJECT_ID``, ``ICUSTAY_ID`` and
            ``SEQ_NUM`` columns.
        output_path: Directory in which the per-subject folders are created.
        subjects: Optional sized iterable of subject ids to export; defaults
            to every distinct ``SUBJECT_ID`` in ``diagnoses``.

    Raises:
        OSError: If a subject directory cannot be created.
    """
    subjects = diagnoses.SUBJECT_ID.unique() if subjects is None else subjects
    # len() also accepts plain lists and tuples, unlike ``.shape[0]``.
    nb_subjects = len(subjects)
    for subject_id in tqdm(subjects, total=nb_subjects, desc='Breaking up diagnoses by subjects'):
        dn = os.path.join(output_path, str(subject_id))
        try:
            os.makedirs(dn)
        except OSError:
            # An already-existing directory is fine; anything else (for
            # example a permission problem) is a real error.
            if not os.path.isdir(dn):
                raise

        subject_diagnoses = diagnoses[diagnoses.SUBJECT_ID == subject_id]
        subject_diagnoses = subject_diagnoses.sort_values(by=['ICUSTAY_ID', 'SEQ_NUM'])
        subject_diagnoses.to_csv(os.path.join(dn, 'diagnoses.csv'), index=False)


def read_events_table_by_row(mimic4_path, table):
    """Stream an events CSV one row at a time.

    Opens ``<mimic4_path>/<table>.csv`` and yields each record as parsed by
    ``csv.DictReader``. Rows lacking a ``stay_id`` field get an empty one so
    that consumers can rely on the key being present.

    Args:
        mimic4_path: Root directory of the MIMIC-IV CSV export.
        table: Table path relative to the root without extension, one of
            ``'icu/chartevents'``, ``'hosp/labevents'`` or
            ``'icu/outputevents'`` (matched case-insensitively for the row
            count lookup).

    Yields:
        Tuples ``(row, row_index, expected_row_count)`` where ``row_index``
        starts at 0 and ``expected_row_count`` is a hard-coded figure for the
        table.

    Raises:
        KeyError: If ``table`` has no known row count.
    """
    # nb_rows = {'chartevents': 330712484, 'labevents': 27854056, 'outputevents': 4349219}
    nb_rows = {'icu/chartevents': 327363275, 'hosp/labevents': 122289829, 'icu/outputevents': 4248829}
    # The context manager closes the file once the generator is exhausted or
    # closed, instead of leaking the handle.
    with open(os.path.join(mimic4_path, table + '.csv'), 'r') as csv_file:
        expected_rows = nb_rows[table.lower()]
        for i, row in enumerate(csv.DictReader(csv_file)):
            if 'stay_id' not in row:
                row['stay_id'] = ''
            yield row, i, expected_rows


def read_events_table_and_break_up_by_subject(mimic4_path, table, output_path,
                                              items_to_keep=None, subjects_to_keep=None):
    """Split an events table into per-subject ``events.csv`` files.

    Rows are streamed from ``read_events_table_by_row``, filtered by subject
    (and optionally by item id), buffered while the subject id stays the same,
    and appended to ``<output_path>/<subject_id>/events.csv`` whenever the
    subject changes and once more at the end. The file gets a header line the
    first time it is created; later flushes for the same subject append.

    Args:
        mimic4_path: Root directory of the MIMIC-IV CSV export.
        table: One of ``'icu/chartevents'``, ``'hosp/labevents'`` or
            ``'icu/outputevents'``.
        output_path: Directory in which the per-subject folders are created.
        items_to_keep: Optional iterable of item ids; when given, rows with
            any other ``itemid`` are skipped.
        subjects_to_keep: Iterable of subject ids to export. Required despite
            the ``None`` default.

    Raises:
        KeyError: If ``table`` is not one of the known tables.
        AssertionError: If ``subjects_to_keep`` is ``None``.
        OSError: If a subject directory cannot be created.
    """
    obs_header = ['SUBJECT_ID', 'HADM_ID', 'ICUSTAY_ID', 'CHARTTIME', 'ITEMID', 'VALUE', 'VALUEUOM']
    # CSV fields are read as strings, so compare against string ids.
    if items_to_keep is not None:
        items_to_keep = set(str(s) for s in items_to_keep)
    if subjects_to_keep is not None:
        subjects_to_keep = set(str(s) for s in subjects_to_keep)

    class DataStats(object):
        """Mutable holder for the subject currently being buffered."""

        def __init__(self):
            self.curr_subject_id = ''
            self.curr_obs = []

    data_stats = DataStats()

    def write_current_observations():
        """Append the buffered rows to the current subject's events.csv."""
        dn = os.path.join(output_path, str(data_stats.curr_subject_id))
        try:
            os.makedirs(dn)
        except OSError:
            # An already-existing directory is expected; re-raise real errors.
            if not os.path.isdir(dn):
                raise
        fn = os.path.join(dn, 'events.csv')
        needs_header = not os.path.isfile(fn)
        # A single managed handle guarantees the rows are flushed and the
        # file is closed before the next subject is processed.
        with open(fn, 'a') as events_file:
            if needs_header:
                events_file.write(','.join(obs_header) + '\n')
            w = csv.DictWriter(events_file, fieldnames=obs_header, quoting=csv.QUOTE_MINIMAL)
            w.writerows(data_stats.curr_obs)
        data_stats.curr_obs = []

    # NOTE: Will have to update these on MIMIC-IV changes
    nb_rows_dict = {'icu/chartevents': 327363275, 'hosp/labevents': 122289829, 'icu/outputevents': 4248829}

    nb_rows = nb_rows_dict[table]

    # Raised explicitly (same exception type as the former assert) so the
    # check is not stripped when Python runs with -O.
    if subjects_to_keep is None:
        raise AssertionError('subjects_to_keep must not be None')

    for row, _, _ in tqdm(read_events_table_by_row(mimic4_path, table), total=nb_rows,
                          desc='Processing {} table'.format(table)):

        if row['subject_id'] not in subjects_to_keep:
            continue
        if (items_to_keep is not None) and (row['itemid'] not in items_to_keep):
            continue

        row_out = {'SUBJECT_ID': row['subject_id'],
                   'HADM_ID': row['hadm_id'],
                   'ICUSTAY_ID': '' if 'stay_id' not in row else row['stay_id'],
                   'CHARTTIME': row['charttime'],
                   'ITEMID': row['itemid'],
                   'VALUE': row['value'],
                   'VALUEUOM': row['valueuom']}

        # Flush the buffer as soon as a row for a different subject shows up.
        if data_stats.curr_subject_id != '' and data_stats.curr_subject_id != row['subject_id']:
            write_current_observations()
        data_stats.curr_obs.append(row_out)
        data_stats.curr_subject_id = row['subject_id']

    # Flush whatever is left for the last subject seen.
    if data_stats.curr_subject_id != '':
        write_current_observations()