# -*- coding: utf-8 -*-


# run:   /s/python-3.3.3/bin/python3 code.py
import sys; sys.path.insert(0, "site-packages")
import time

import numpy as np
from sklearn.linear_model import LassoCV


## data_path must point to the location of the datasets created by
## the R cross-validation script code-section-6A.R. It is prepended
## verbatim to every file name, so include a trailing path separator.
# Prefix prepended verbatim to every input and output file name
# ("" means the current working directory).
data_path   = ""

# Number of timed fits that are averaged for each configuration.
num_reps   = 10

# Benchmark grid: CV fold counts, numbers of predictors, numbers of observations.
k_vec      = [10]
nvars_vec  = [50, 100, 250, 500]
nobs_vec   = [1e5, 1e6]
# Exponents used to build the data file names; must line up
# index-by-index with nobs_vec (nobs_vec[i] == 10 ** nobs_e_vec[i]).
nobs_e_vec = [5, 6]



# One row per (nvars, nobs, k) configuration; columns are
# nobs, nvars, k, average fit time in seconds.
# Pre-filled with NaN so that rows which have not been measured yet are
# recognizable instead of holding whatever was in uninitialized memory.
result_time_vec = np.empty((len(nobs_vec) * len(nvars_vec) * len(k_vec), 4), dtype=float)
result_time_vec.fill(np.nan)

# Index of the next row of result_time_vec to fill.
ct = 0

# Benchmark loop: for every (nvars, nobs) dataset, fit LassoCV num_reps times
# for each fold count in k_vec, record the average wall time of one fit in
# result_time_vec, and checkpoint the results to CSV after each configuration.
for nvars in nvars_vec:
    for idx_nobs, nobs in enumerate(nobs_vec):
        nobs_e = nobs_e_vec[idx_nobs]

        ## read in data
        # All three files of a dataset share the same name suffix.
        file_suffix = '_nvars' + str(nvars) + '_nobs1e+0' + str(nobs_e) + '.csv'
        x_train = np.genfromtxt(data_path + 'x' + file_suffix, delimiter=',')
        y_train = np.genfromtxt(data_path + 'y' + file_suffix, delimiter=',')
        lam     = np.genfromtxt(data_path + 'lam' + file_suffix, delimiter=',')
        # Drop the first row/entry of each array
        # (presumably the CSV header line, which parses as NaN -- TODO confirm).
        x_train = np.delete(x_train, 0, axis=0)
        y_train = np.delete(y_train, 0, axis=0)
        lam     = np.delete(lam, 0, axis=0)

        print("nvars")
        print(str(nvars))
        print("nobs")
        print(str(nobs))

        for k in k_vec:
            print(str(k))

            # Reset the accumulator for every fold count so that the average
            # of one k is not contaminated by the result of the previous k.
            t_lasso_cv = 0.0
            for rep in range(num_reps):
                # perf_counter() is a monotonic, high-resolution clock meant
                # for measuring elapsed intervals; time.time() can jump.
                t1 = time.perf_counter()
                model = LassoCV(cv=k, alphas=lam).fit(x_train, y_train)
                t2 = time.perf_counter()
                t_lasso_cv += t2 - t1
                print("time = " + str(t2 - t1))

            t_lasso_cv = t_lasso_cv / num_reps
            print("ave time = " + str(t_lasso_cv))
            result_time_vec[ct, 0] = nobs
            result_time_vec[ct, 1] = nvars
            result_time_vec[ct, 2] = k
            result_time_vec[ct, 3] = t_lasso_cv
            ct = ct + 1

            # Checkpoint after every configuration so partial results survive
            # an interrupted run; write only the rows measured so far.
            np.savetxt(data_path + 'sklearn_lasso_times_10fold.csv', result_time_vec[:ct], delimiter=',')
