###################################################################
# Replication file for Cattaneo, Feng, Palomba, and Titiunik (2022)
###################################################################

########################################
# Load SCPI_PKG package
import pandas
import numpy
import random
import os
from warnings import filterwarnings
from plotnine import ggtitle, ggsave

from scpi_pkg.scdata import scdata
from scpi_pkg.scdataMulti import scdataMulti
from scpi_pkg.scest import scest
from scpi_pkg.scpi import scpi
from scpi_pkg.scplot import scplot
from scpi_pkg.scplotMulti import scplotMulti

filterwarnings("ignore")

########################################
# One feature (gdp)
########################################

########################################
# Load database
data = pandas.read_csv("Data/scpi_germany.csv")

########################################
# Set options for data preparation
id_var = "country"
outcome_var = "gdp"
time_var = "year"
period_pre = numpy.arange(1960, 1991)    # 1960, ..., 1990 (end excluded)
period_post = numpy.arange(1991, 2004)   # 1991, ..., 2003 (end excluded)
unit_tr = "West Germany"
# Donor pool: every unit in the data other than the treated one.
# drop_duplicates() keeps first-appearance order, so the list is the same on
# every run; iterating a set of strings is not, because str hashing is
# randomized per interpreter process.
unit_co = [cou for cou in data[id_var].drop_duplicates().to_list()
           if cou != unit_tr]
constant = True
cointegrated_data = True

# Build the data object consumed by scest/scpi below.
data_prep = scdata(df = data, id_var = id_var, time_var = time_var,
                   outcome_var = outcome_var, period_pre = period_pre,
                   period_post = period_post, unit_tr = unit_tr,
                   unit_co = unit_co, cointegrated_data = cointegrated_data,
                   constant = constant)

####################################
# SC - point estimation with simplex
# Estimate on the prepared data under the "simplex" weight constraint and
# print the returned object.
est_si = scest(data_prep, w_constr=dict(name="simplex"))
print(est_si)

####################################
# Set options for inference
# NOTE: w_constr is not read by the loop below; each iteration passes its own
# {"name": mtd} constraint to scpi.
w_constr = dict(name="simplex", Q=1)
u_missp, u_sigma = True, "HC1"
u_order, u_lags = 1, 0
e_method = "gaussian"
e_order, e_lags = 1, 0
sims, cores = 1000, 1

# One scpi run and one plot per weight-constraint type.
for mtd in ("lasso", "simplex", "ridge", "L1-L2", "ols"):
    # "ridge" and "L1-L2" use lgapp = "generalized"; the others "linear".
    lgapp = "generalized" if mtd in ("ridge", "L1-L2") else "linear"

    # Reset NumPy's global seed so every method starts from the same state.
    numpy.random.seed(8894)
    pi_si = scpi(data_prep, w_constr={"name": mtd}, lgapp=lgapp,
                 sims=sims, cores=cores, u_missp=u_missp, u_sigma=u_sigma,
                 u_order=u_order, u_lags=u_lags,
                 e_method=e_method, e_order=e_order, e_lags=e_lags)

    # Build the plot and give it an empty title before displaying it.
    plot = scplot(pi_si, x_lab="Year", e_method=e_method,
                  y_lab="GDP per capita (thousand US dollars)") + ggtitle("")
    print(plot)


########################################
# Multiple features (gdp, trade)
########################################

########################################
# Load database
data = pandas.read_csv("Data/scpi_germany.csv")

########################################
# Set options for data preparation
id_var = "country"
outcome_var = "gdp"
time_var = "year"
period_pre = numpy.arange(1960, 1991)    # 1960, ..., 1990 (end excluded)
period_post = numpy.arange(1991, 2004)   # 1991, ..., 2003 (end excluded)
unit_tr = "West Germany"
# Donor pool: every unit in the data other than the treated one.
# drop_duplicates() keeps first-appearance order, so the list is the same on
# every run; iterating a set of strings is not, because str hashing is
# randomized per interpreter process.
unit_co = [cou for cou in data[id_var].drop_duplicates().to_list()
           if cou != unit_tr]
constant = False
cointegrated_data = True
# Two adjustment lists for the two features passed below -- presumably one
# per feature, in the same order; confirm against the scdata documentation.
cov_adj = [["constant"], ["constant"]]

# Build the data object consumed by scest/scpi below, matching on both
# "gdp" and "trade".
data_prep = scdata(df = data, id_var = id_var, time_var = time_var,
                   outcome_var = outcome_var, period_pre = period_pre,
                   period_post = period_post, unit_tr = unit_tr, constant = constant,
                   unit_co = unit_co, cointegrated_data = cointegrated_data,
                   features = ["gdp", "trade"], cov_adj = cov_adj)

####################################
# SC - point estimation with simplex
# Estimate on the multi-feature data under the "simplex" weight constraint
# and print the returned object.
est_si = scest(data_prep, w_constr=dict(name="simplex"))
print(est_si)

####################################
# Set options for inference
# NOTE: w_constr is not read by the loop below; each iteration passes its own
# {"name": mtd} constraint to scpi.
w_constr = dict(name="simplex", Q=1)
u_missp, u_sigma = True, "HC1"
u_order, u_lags = 1, 0
e_method = "gaussian"
e_order, e_lags = 1, 0
sims, cores = 1000, 1

# One scpi run and one plot per weight-constraint type.
for mtd in ("simplex", "lasso", "ridge", "L1-L2", "ols"):
    # "ridge" and "L1-L2" use lgapp = "generalized"; the others "linear".
    lgapp = "generalized" if mtd in ("ridge", "L1-L2") else "linear"

    # Reset NumPy's global seed so every method starts from the same state.
    numpy.random.seed(8894)
    pi_si = scpi(data_prep, w_constr={"name": mtd}, lgapp=lgapp,
                 sims=sims, cores=cores, u_missp=u_missp, u_sigma=u_sigma,
                 u_order=u_order, u_lags=u_lags,
                 e_method=e_method, e_order=e_order, e_lags=e_lags)

    # Build the plot (joint = True) with an empty title, then display it.
    plot = scplot(pi_si, x_lab="Year", e_method=e_method, joint=True,
                  y_lab="GDP per capita (thousand US dollars)") + ggtitle("")
    print(plot)
