Bayes Lasso using EBMR
import numpy as np
import pandas as pd
from scipy import linalg as sc_linalg
import matplotlib.pyplot as plt
import sys
sys.path.append("../../ebmrPy/")
from utils import log_density
from inference import f_elbo
from inference import penalized_em
from inference.ebmr import EBMR
import ipdb
sys.path.append("../../utils/")
import mpl_stylesheet
# Apply the shared presentation style from the project-local mpl_stylesheet
# module (made importable by the sys.path.append just above it).
mpl_stylesheet.banskt_presentation(
    fontfamily='latex-clearsans',
    fontsize=18,
    colors='banskt',
    dpi=72,
)
def standardize(X):
    """Center every column of X and rescale it to unit Euclidean norm.

    Note: this is *not* unit-variance scaling. After the transform each column
    has mean 0 and sum of squares 1.

    Parameters
    ----------
    X : array_like
        Design matrix; columns are reduced along axis 0.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as X. Columns whose centred values are all
        exactly zero (constant columns) are returned as zeros rather than NaN.
    """
    X = np.asarray(X)
    Xnorm = X - np.mean(X, axis=0)
    colnorm = np.sqrt((Xnorm * Xnorm).sum(axis=0))
    # A constant column centres to all zeros; dividing by its zero norm would
    # give 0/0 = NaN and contaminate any regression fitted on the result.
    # Divide such columns by 1 instead so they stay zero.
    safe_norm = np.where(colnorm > 0, colnorm, 1.0)
    return Xnorm / safe_norm
def lasso_data(nsample, nvar, neff, errsigma, sb2=100, seed=200):
    """Simulate a sparse linear-regression data set.

    Side effect: re-seeds NumPy's *global* legacy random state with ``seed``
    before drawing anything.

    Parameters
    ----------
    nsample, nvar : int
        Number of rows and columns of the design matrix.
    neff : int
        Number of predictors given a non-zero coefficient. They are chosen
        without replacement, so ``neff`` must not exceed ``nvar``.
    errsigma : float
        Standard deviation of the Gaussian residual noise.
    sb2 : float
        Variance of the normal draws used for the non-zero coefficients.
    seed : int
        Seed passed to ``np.random.seed``.

    Returns
    -------
    (X, y, btrue)
        ``X``: standard-normal draws passed through ``standardize``.
        ``y``: ``X @ btrue`` plus noise, then mean-centred.
        ``btrue``: length-``nvar`` coefficient vector.
    """
    np.random.seed(seed)
    raw = np.random.normal(0, 1, nsample * nvar).reshape(nsample, nvar)
    design = standardize(raw)

    coef = np.zeros(nvar)
    nonzero = np.random.choice(nvar, neff, replace=False)
    coef[nonzero] = np.random.normal(0, np.sqrt(sb2), neff)

    response = design @ coef + np.random.normal(0, errsigma, nsample)
    response -= response.mean()
    return design, response, coef
def lims_xy(ax):
    """Return ``[lo, hi]`` covering both the x- and y-axis limits of ``ax``.

    ``lo`` is the smallest and ``hi`` the largest of the four limit values
    reported by ``ax.get_xlim()`` and ``ax.get_ylim()``.
    """
    corners = np.array([ax.get_xlim(), ax.get_ylim()])
    return [corners.min(), corners.max()]
def plot_diag(ax):
    """Overlay the identity line y = x on ``ax`` as a dotted gray segment.

    The segment's endpoints are the combined lower/upper limits returned by
    ``lims_xy`` for the axes' current x and y ranges.
    """
    endpoints = lims_xy(ax)
    ax.plot(endpoints, endpoints, color='gray', ls='dotted')
# Simulation settings: n samples and p predictors, of which peff get non-zero
# coefficients drawn with variance sb2; sd is the residual noise std. dev.
n = 50
p = 100
peff = 10
sb2 = 100.0
sd = 2.0
X, y, btrue = lasso_data(n, p, peff, sd, sb2)

# Sanity check: noiseless linear predictor X @ btrue against the observed y.
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.scatter(X @ btrue, y)
plot_diag(ax1)
plt.show()
# Fit EBMR twice on the same data, changing only the coefficient prior:
# 'dexp' (presumably double-exponential / Laplace, i.e. a Bayesian lasso)
# and 'point'. Every other fitting option is shared.
eblasso = EBMR(
    X, y,
    prior='dexp',
    grr='em',
    sigma='full',
    inverse='direct',
    max_iter=1000,
    tol=1e-8,
)
ebridge = EBMR(
    X, y,
    prior='point',
    grr='em',
    sigma='full',
    inverse='direct',
    max_iter=1000,
    tol=1e-8,
)
eblasso.update()
ebridge.update()

# Side-by-side summary of the fitted hyperparameters and ELBO, one row per prior.
data = {
    's2': [eblasso.s2, ebridge.s2],
    'sb2': [eblasso.sb2, ebridge.sb2],
    's2 x sb2': [eblasso.s2 * eblasso.sb2, ebridge.s2 * ebridge.sb2],
    'ELBO': [eblasso.elbo, ebridge.elbo],
}
resdf = pd.DataFrame.from_dict(data)
resdf.index = ['dexp', 'point']

# The bare expressions below only display output when run as notebook cells;
# in a plain script they are evaluated and their values discarded.
resdf.round(decimals=3)
eblasso.mll_path
eblasso.elbo_path
# Observed response vs fitted values X @ mu for both priors. Points near the
# dotted y = x line indicate good in-sample predictions. Labels and a legend
# are needed to tell the two colour-coded fits apart.
fig = plt.figure()
ax1 = fig.add_subplot(111)
ypred_lasso = np.dot(X, eblasso.mu)
ypred_ridge = np.dot(X, ebridge.mu)
ax1.scatter(y, ypred_lasso, color='salmon', label='dexp prior')
ax1.scatter(y, ypred_ridge, color='dodgerblue', label='point prior')
plot_diag(ax1)
ax1.set_xlabel('observed y')
ax1.set_ylabel('predicted y')
ax1.legend()  # the unlabeled diagonal line is excluded automatically
plt.show()
# Estimated posterior means (mu) for both priors against the simulated true
# coefficients, indexed by predictor. A legend and axis labels identify the
# three colour-coded series.
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.scatter(np.arange(p), btrue, color='black', s=10, label='true')
ax1.scatter(np.arange(p), eblasso.mu, color='salmon', label='dexp prior')
ax1.scatter(np.arange(p), ebridge.mu, color='dodgerblue', label='point prior')
ax1.set_xlabel('predictor index')
ax1.set_ylabel('coefficient')
ax1.legend()
plt.show()