import numpy as np
import pandas as pd
import os
import subprocess
import tempfile

from mrashpen.inference.penalized_regression import PenalizedRegression as PLR
from mrashpen.inference.mrash_wrapR          import MrASHR
from mrashpen.models.plr_ash                 import PenalizedMrASH
from mrashpen.models.normal_means_ash_scaled import NormalMeansASHScaled
from mrashpen.inference.ebfit                import ebfit
from mrashpen.utils                          import R_utils
from mrashpen.utils                          import R_lasso

import sys
sys.path.append('/home/saikat/Documents/work/sparse-regression/simulation/eb-linreg-dsc/dsc/functions')
import simulate

import matplotlib.pyplot as plt
from pymir import mpl_stylesheet
from pymir import mpl_utils
mpl_stylesheet.banskt_presentation(splinecolor = 'black')


def center_and_scale(Z):
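    # Standardize Z to unit standard deviation and zero mean (column-wise for 2-d arrays).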
    dim = Z.ndim
    if dim == 1:
        Znew = Z / np.std(Z)
        Znew = Znew - np.mean(Znew)
    elif dim == 2:
        Znew = Z / np.std(Z, axis = 0)
        Znew = Znew - np.mean(Znew, axis = 0).reshape(1, -1)
    return Znew

def initialize_ash_prior(k, scale = 2):
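    # ash mixture prior with k components: component standard deviations lie on the
    # grid |scale**(j/k) - 1| (the first is a point mass at zero); the first component
    # gets zero weight and the remaining k-1 components share the mass equally.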
    w = np.zeros(k)
    w[0] = 0.0
    w[1:(k-1)] = np.repeat((1 - w[0])/(k-1), (k - 2))
    w[k-1] = 1 - np.sum(w)
    sk2 = np.square((np.power(scale, np.arange(k) / k) - 1))
    prior_grid = np.sqrt(sk2)
    return w, prior_grid
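
# Quick sanity check (illustrative): the mixture weights sum to 1 and the first
# grid point is 0, i.e. a point mass at zero.
# wk, sk = initialize_ash_prior(20); assert np.isclose(np.sum(wk), 1.0) and sk[0] == 0.0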

def plot_linear_mrashpen(X, y, Xtest, ytest, btrue, strue, bhat, 
                         intercept = 0, title = None):
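    # Left: predicted vs observed response on the test set; right: estimated vs true coefficients.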
    ypred = np.dot(Xtest, bhat) + intercept
    fig = plt.figure(figsize = (12, 6))
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)
    ax1.scatter(ytest, ypred, s = 2, alpha = 0.5)
    mpl_utils.plot_diag(ax1)
    ax2.scatter(btrue, bhat)
    mpl_utils.plot_diag(ax2)

    ax1.set_xlabel("Y_test")
    ax1.set_ylabel("Y_predicted")
    ax2.set_xlabel("True b")
    ax2.set_ylabel("Predicted b")
    if title is not None:
        fig.suptitle(title)
    plt.tight_layout()
    plt.show()
    
    
def plot_convergence(objs, methods, nwarm, eps = 1e-8):
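    # For each method, plot the log10 gap between its objective and the best value
    # reached, skipping the warm-up iterations given in nwarm.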
    fig = plt.figure(figsize = (12, 6))
    ax1 = fig.add_subplot(111)

    # Reference: the smallest objective value reached by any method. If the stored
    # objective is the negative ELBO, obj - objmin equals max(ELBO) - ELBO, matching
    # the y-axis label below.
    objmin  = np.min([np.min(x) for x in objs])

    for obj, method, iteq in zip(objs, methods, nwarm):
        m_obj = obj[iteq:] - objmin
        keep  = m_obj > eps                     # keep only gaps above eps before taking log
        xvals = np.arange(iteq, iteq + len(m_obj))[keep]
        ax1.plot(xvals, np.log10(m_obj[keep]), label = method)
    ax1.legend()

    ax1.set_xlabel("Number of Iterations")
    ax1.set_ylabel("log( max(ELBO) - ELBO )")

    plt.show()
    return

def plot_trendfilter_mrashpen(X, y, beta, ytest, bhat,
                              intercept = 0, title = None):
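    # Left: test responses (open circles) and fitted values against sample index;
    # right: true (open) and estimated (filled) coefficients against predictor index.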
    n = y.shape[0]
    p = X.shape[1]

    ypred = np.dot(X, bhat) + intercept
    fig = plt.figure(figsize = (12, 6))
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)
    ax1.scatter(np.arange(n), ytest, edgecolor = 'black', facecolor='white')
    ax1.plot(np.arange(n), ypred)
    ax1.set_xlabel("Sample index")
    ax1.set_ylabel("Y")

    ax2.scatter(np.arange(p), beta, edgecolor = 'black', facecolor = 'white')
    ax2.scatter(np.arange(p), bhat, s = 40, color = 'firebrick')
    ax2.set_xlabel("Predictor index")
    ax2.set_ylabel("b")
    
    if title is not None:
        fig.suptitle(title)

    plt.tight_layout()
    plt.show()
    
def linreg_summary_df(sigma2, objs, methods):
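    # Tabulate the estimated residual variance, final ELBO and number of iterations
    # for each method, with the true sigma2 in the first row.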
    data     = [[sigma2, '-', '-']]
    rownames = ['True']
    for obj, method in zip(objs, methods):
        data.append([obj.residual_var, obj.elbo_path[-1], obj.niter])
        rownames.append(method)
    colnames = ['sigma2', 'ELBO', 'niter']
    df = pd.DataFrame.from_records(data, columns = colnames, index = rownames)
    return df
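
# Simulation settings: n samples, p predictors of which p_causal have non-zero
# effects, target proportion of variance explained (pve), and k mixture
# components for the ash prior.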
n = 200
p = 2000
p_causal = 50
pve = 0.95
k = 20

X, y, Xtest, ytest, btrue, strue = simulate.equicorr_predictors(n, p, p_causal, pve, rho = 0.0, seed = 100)
X      = center_and_scale(X)
Xtest  = center_and_scale(Xtest)
wk, sk = initialize_ash_prior(k, scale = 2)
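
# Lasso helpers: a cross-validated Lasso fitted either through an external
# Rscript or with scikit-learn.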
def r_lasso(X, y, nfolds = 10, alpha = 1.0):
    # Write the data to a temporary RDS file, run the external R script, and read
    # the fitted model back from a second temporary RDS file.
    rscript_file = "utils/fit_lasso.R"
    os_handle, data_rds_file = tempfile.mkstemp(suffix = ".rds")
    os.close(os_handle)
    datadict = {'X': X, 'y': y}
    R_utils.save_rds(datadict, data_rds_file)
    os_handle, out_rds_file = tempfile.mkstemp(suffix = ".rds")
    os.close(os_handle)
    cmd  = ["Rscript",   rscript_file]
    cmd += ["--outfile", out_rds_file]
    cmd += ["--infile",  data_rds_file]
    cmd += ["--nfolds", f"{nfolds}"]
    cmd += ["--alpha",  f"{alpha}"]

    process = subprocess.Popen(cmd,
                               stdout = subprocess.PIPE,
                               stderr = subprocess.PIPE
                              )
    res     = process.communicate()
    if len(res[0].decode('utf-8')) > 0:
        print(res[0].decode('utf-8'))
    if len(res[1].decode('utf-8')) > 0:
        print("ERROR ==>")
        print(res[1].decode('utf-8'))
    retcode  = process.returncode
    fit_dict = R_utils.load_rds(out_rds_file) if retcode == 0 else None
    if os.path.exists(data_rds_file): os.remove(data_rds_file)
    if os.path.exists(out_rds_file):  os.remove(out_rds_file)
    if fit_dict is None:
        raise RuntimeError(f"{rscript_file} exited with return code {retcode}")
    intercept = fit_dict['mu']
    coef      = fit_dict['beta']
    return intercept, coef, fit_dict

def sklearn_lasso(X, y, nfolds = 10):
    # Cross-validate to choose the penalty strength, then refit at the selected alpha.
    from sklearn import linear_model
    lasso_cv = linear_model.LassoCV(cv = nfolds, max_iter = 10000)
    lasso_cv.fit(X, y)
    clf = linear_model.Lasso(alpha = lasso_cv.alpha_)
    clf.fit(X, y)
    return clf.intercept_, clf.coef_, clf
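
# Fit the Lasso with both backends on the simulated data.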
rlasso_a0, rlasso_b, rlasso_fit = R_lasso.fit(X, y)

sklasso_a0, sklasso_b, sklasso_fit = sklearn_lasso(X, y)
/home/saikat/.conda/envs/py39mkl/lib/python3.9/site-packages/sklearn/linear_model/_coordinate_descent.py:530: ConvergenceWarning: Objective did not converge. You might want to increase the number of iterations. Duality gap: 1.6615756547953096, tolerance: 0.8931956203813735
  model = cd_fast.enet_coordinate_descent(
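
# Compare test-set predictions and coefficient recovery for the two Lasso fits.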
plot_linear_mrashpen(X, y, Xtest, ytest, btrue, strue, 
                     rlasso_b, intercept = rlasso_a0, title = 'glmnet Lasso')

plot_linear_mrashpen(X, y, Xtest, ytest, btrue, strue, 
                     sklasso_b, intercept = sklasso_a0, title = 'sklearn Lasso')