About

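A sanity check of the Bayesian Lasso method using EBMR: the `dexp` (lasso) and `point` (ridge) priors are fitted to the same simulated data and compared.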

import numpy as np
import pandas as pd
from scipy import linalg as sc_linalg
import matplotlib.pyplot as plt

import sys
sys.path.append("../../ebmrPy/")
from utils import log_density
from inference import f_elbo
from inference import penalized_em
from inference.ebmr import EBMR

import ipdb

sys.path.append("../../utils/")
import mpl_stylesheet
mpl_stylesheet.banskt_presentation(fontfamily = 'latex-clearsans', fontsize = 18, colors = 'banskt', dpi = 72)
def standardize(X):
    """Center each column of X and scale it to unit Euclidean norm.

    The column mean is subtracted first; each centered column is then
    divided by the square root of its sum of squares, so every column of
    the result has zero mean and unit length (not unit variance).
    """
    centered = X - np.mean(X, axis = 0)
    col_norms = np.sqrt(np.sum(np.square(centered), axis = 0))
    return centered / col_norms


def lasso_data(nsample, nvar, neff, errsigma, sb2 = 100, seed=200):
    """Simulate a sparse linear-regression data set.

    Parameters
    ----------
    nsample : int
        Number of observations (rows of X).
    nvar : int
        Number of predictors (columns of X).
    neff : int
        Number of predictors that receive a non-zero effect.
    errsigma : float
        Standard deviation of the Gaussian noise added to the response.
    sb2 : float
        Variance of the non-zero effects.
    seed : int
        Passed to np.random.seed, i.e. the global NumPy RNG is reset.

    Returns
    -------
    X : (nsample, nvar) design matrix, passed through standardize().
    y : response of length nsample, mean-centered but not rescaled.
    btrue : coefficient vector of length nvar, zero outside the neff
        randomly chosen positions.
    """
    np.random.seed(seed)
    # The order of the random draws (design, support, effects, noise)
    # determines the simulated data for a given seed.
    X = standardize(np.random.normal(0, 1, size = (nsample, nvar)))
    support = np.random.choice(nvar, neff, replace = False)
    effects = np.random.normal(0, np.sqrt(sb2), neff)
    noise = np.random.normal(0, errsigma, nsample)

    btrue = np.zeros(nvar)
    btrue[support] = effects

    y = np.dot(X, btrue) + noise
    return X, y - np.mean(y), btrue

def lims_xy(ax):
    """Return [lo, hi] spanning both the x- and y-limits of `ax`.

    `lo` is the smallest and `hi` the largest value found among the
    current x-axis and y-axis limits.
    """
    bounds = [ax.get_xlim(), ax.get_ylim()]
    return [np.min(bounds), np.max(bounds)]

def plot_diag(ax):
    """Draw a dotted gray identity line (y = x) on `ax`.

    The line runs between the combined x/y limits returned by lims_xy(),
    so it should be called after the data have been plotted.
    """
    span = lims_xy(ax)
    ax.plot(span, span, linestyle='dotted', color='gray')
# Simulation settings: fewer samples (n) than predictors (p), with only
# peff predictors carrying a non-zero effect.
n, p, peff = 50, 100, 10
sb2, sd = 100.0, 2.0    # effect variance, noise standard deviation
X, y, btrue = lasso_data(n, p, peff, sd, sb2 = sb2)

# Observed response against the noiseless signal X.btrue, with the
# identity line for reference.
fig, ax1 = plt.subplots()
ax1.scatter(np.dot(X, btrue), y)
plot_diag(ax1)
plt.show()
# Two EBMR models on the same data; they differ only in the prior.
eblasso = EBMR(
    X, y,
    prior='dexp',
    grr='em',
    sigma='full',
    inverse='direct',
    max_iter=1000,
    tol=1e-8,
)
ebridge = EBMR(
    X, y,
    prior='point',
    grr='em',
    sigma='full',
    inverse='direct',
    max_iter=1000,
    tol=1e-8,
)
2020-12-01 12:55:43,172 | inference.ebmr | DEBUG | EBMR using dexp prior, em grr, full b posterior variance, direct inversion
2020-12-01 12:55:43,173 | inference.ebmr | DEBUG | EBMR using point prior, em grr, full b posterior variance, direct inversion
# Fit both models.
eblasso.update()
ebridge.update()

# One row per prior, one column per fitted quantity.
data = {
    's2': [eblasso.s2, ebridge.s2],
    'sb2': [eblasso.sb2, ebridge.sb2],
    's2 x sb2': [eblasso.s2 * eblasso.sb2, ebridge.s2 * ebridge.sb2],
    'ELBO': [eblasso.elbo, ebridge.elbo],
}
resdf = pd.DataFrame(data, index=['dexp', 'point'])
resdf.round(3)
          s2    sb2  s2 x sb2      ELBO
dexp   0.975  6.257     6.102  -158.132
point  4.243  1.092     4.633  -139.934
# Display the `mll_path` trace of the dexp-prior fit
# (presumably the marginal log-likelihood per iteration -- confirm in inference.ebmr).
eblasso.mll_path
array([         -inf, -162.5211116 , -159.397836  , -158.49845882,
       -158.25413555, -158.17937266, -158.15315927, -158.14270437,
       -158.13802668, -158.1358067 , -158.13452536, -158.13384434,
       -158.1334242 , -158.13317589, -158.13305632, -158.13294788,
       -158.1328467 , -158.13275135, -158.13266116, -158.13257572,
       -158.13249471, -158.13241787, -158.13234497, -158.13227578,
       -158.13221009, -158.13214772])
# Display the `elbo_path` trace of the dexp-prior fit
# (presumably the ELBO per iteration -- confirm in inference.ebmr).
# NOTE(review): the recorded trace is not monotone at its last step
# (-158.12938758 -> -158.12939083); worth checking against the stopping rule.
eblasso.elbo_path
array([         -inf, -165.0350422 , -160.95177667, -159.17931935,
       -158.55105265, -158.31597417, -158.21862614, -158.17454958,
       -158.15315208, -158.1421157 , -158.13627118, -158.13311766,
       -158.13132908, -158.13045503, -158.13032989, -158.13007672,
       -158.12985772, -158.12969343, -158.1295773 , -158.12949848,
       -158.1294471 , -158.12941533, -158.12939738, -158.1293891 ,
       -158.12938758, -158.12939083])
# Fitted values from each model's coefficient estimate `mu`.
ypred_lasso = np.dot(X, eblasso.mu)
ypred_ridge = np.dot(X, ebridge.mu)

# Fitted against observed response: dexp prior in salmon, point prior in
# blue, with the identity line for reference.
fig, ax1 = plt.subplots()
ax1.scatter(y, ypred_lasso, color='salmon')
ax1.scatter(y, ypred_ridge, color='dodgerblue')
plot_diag(ax1)
plt.show()
# Coefficients by predictor index: true values (small black dots) overlaid
# with the dexp-prior (salmon) and point-prior (blue) estimates.
fig, ax1 = plt.subplots()
var_idx = np.arange(p)
ax1.scatter(var_idx, btrue, color='black', s=10)
ax1.scatter(var_idx, eblasso.mu, color='salmon')
ax1.scatter(var_idx, ebridge.mu, color='dodgerblue')
plt.show()