In this demo, I illustrate the use of the mr-ash-pen package for sparse multiple linear regression and trend filtering. Coefficients are initialized at zero. The residual variance is initialized at 1. The prior initialization is a bit tricky. Currently, I am using the same prior initialization throughout (20 components, as described in Kim et al.), but different initializations give different results.

import numpy as np
from mrashpen.inference.penalized_regression import PenalizedRegression as PLR
from mrashpen.models.normal_means_ash import NormalMeansASH

import sys
sys.path.append('/home/saikat/Documents/work/sparse-regression/simulation/eb-linreg-dsc/dsc/functions')
import simulate

import matplotlib.pyplot as plt
from pymir import mpl_stylesheet
from pymir import mpl_utils
# Apply the pymir presentation stylesheet before any figure is created.
# NOTE(review): presumably this sets global matplotlib rcParams -- confirm in pymir.
mpl_stylesheet.banskt_presentation(splinecolor = 'black')

def center_and_scale(Z):
    """
    Standardize a vector, or each column of a matrix.

    The input is first divided by its standard deviation and the mean of the
    scaled values is then subtracted, so the result has zero mean and unit
    standard deviation (per column for 2-D input).

    Parameters
    ----------
    Z : numpy.ndarray
        1-D array, or 2-D array that is standardized column by column
        (statistics are taken along axis 0).

    Returns
    -------
    numpy.ndarray
        New array of the same shape as `Z`; `Z` itself is not modified.

    Raises
    ------
    ValueError
        If `Z` is neither 1-D nor 2-D.

    Notes
    -----
    `np.std` is used with its default `ddof = 0`. A constant vector / column
    has zero standard deviation, so the division yields inf / nan for it.
    """
    dim = Z.ndim
    if dim == 1:
        Znew = Z / np.std(Z)
        Znew = Znew - np.mean(Znew)
    elif dim == 2:
        Znew = Z / np.std(Z, axis = 0)
        # Keep the column means as a (1, ncol) row so they broadcast over rows.
        Znew = Znew - np.mean(Znew, axis = 0).reshape(1, -1)
    else:
        # Previously this fell through and failed on an unbound `Znew`.
        raise ValueError(
            "center_and_scale expects a 1-D or 2-D array, got ndim = {}".format(dim)
        )
    return Znew

def initialize_ash_prior(k, scale = 2):
    """
    Build the initial mixture weights and the prior grid for a k-component prior.

    The first component receives a tiny weight (1e-8); the remaining mass is
    split equally over the other k - 1 components. The grid is
    |scale^(i / k) - 1| for i = 0, ..., k - 1, so its first entry is 0.

    Parameters
    ----------
    k : int
        Number of mixture components; must be at least 2.
    scale : float, optional
        Base of the geometric grid (default 2).

    Returns
    -------
    w : numpy.ndarray, shape (k,)
        Initial mixture weights, summing to 1.
    prior_grid : numpy.ndarray, shape (k,)
        Non-negative grid values, one per component.

    Raises
    ------
    ValueError
        If `k` is smaller than 2.
    """
    if k < 2:
        raise ValueError("initialize_ash_prior requires k >= 2, got k = {}".format(k))
    w = np.zeros(k)
    # Only the first entry is set here; assigning the scalar to `w` itself
    # would discard the array and break the indexed assignments below.
    w[0] = 1e-8
    w[1:(k-1)] = np.repeat((1 - w[0])/(k-1), (k - 2))
    # The last weight absorbs the remainder so that the weights sum to one.
    w[k-1] = 1 - np.sum(w)
    sk2 = np.square((np.power(scale, np.arange(k) / k) - 1))
    prior_grid = np.sqrt(sk2)
    return w, prior_grid


### Linear model with n > p

def plot_linear_mrashpen(X, y, Xtest, ytest, btrue, strue, res):
    """
    Show two diagnostic scatter plots for a fitted linear model.

    Left panel: test responses against predictions `Xtest @ res.coef`.
    Right panel: true coefficients against the fitted coefficients.

    Parameters
    ----------
    X, y :
        Training data. Not used by the plot; kept so the signature matches
        the tuple returned by the simulation.
    Xtest, ytest :
        Test predictors and test responses.
    btrue :
        True coefficients, plotted against `res.coef`.
    strue :
        Not used by the plot.
    res :
        Fitted model object; only its `coef` attribute is read.

    Returns
    -------
    None
        The figure is displayed with `plt.show()`.
    """
    bhat = res.coef
    ypred = np.dot(Xtest, bhat)
    fig = plt.figure(figsize = (12, 6))
    # Create the two side-by-side panels; they were referenced below but
    # never defined, which raised a NameError.
    ax1 = fig.add_subplot(121)
    ax2 = fig.add_subplot(122)

    ax1.scatter(ytest, ypred, s = 2, alpha = 0.5)
    # NOTE(review): plot_diag presumably draws the y = x reference line -- confirm in pymir.
    mpl_utils.plot_diag(ax1)
    ax2.scatter(btrue, bhat)
    mpl_utils.plot_diag(ax2)

    ax1.set_xlabel("Y_test")
    ax1.set_ylabel("Y_predicted")
    ax2.set_xlabel("True b")
    ax2.set_ylabel("Predicted b")
    plt.tight_layout()
    plt.show()


### Generate data

# Settings for the n > p case (n = 2000, p = 200), uncorrelated predictors (rho = 0.0).
# NOTE(review): p_causal and pve are presumably the number of non-zero
# coefficients and the proportion of variance explained -- confirm against
# simulate.equicorr_predictors.
n = 2000
p = 200
p_causal = 20
pve = 0.7
k = 20  # number of prior components passed to initialize_ash_prior

X, y, Xtest, ytest, btrue, strue = simulate.equicorr_predictors(n, p, p_causal, pve, rho = 0.0, seed = 10)
# Training and test predictors are standardized separately, each with its own statistics.
X      = center_and_scale(X)
Xtest  = center_and_scale(Xtest)
# wk: initial mixture weights, sk: prior grid.
wk, sk = initialize_ash_prior(k)

## Optimize
# L-BFGS-B fit starting from binit = None, weights wk and residual variance 1.
plr_lbfgs = PLR(method = 'L-BFGS-B', optimize_w = True, optimize_s = True, is_prior_scaled = True,
debug = False, display_progress = False)
plr_lbfgs.fit(X, y, sk, binit = None, winit = wk, s2init = 1)

## Plot
plot_linear_mrashpen(X, y, Xtest, ytest, btrue, strue, plr_lbfgs)

### Linear model with n < p

### Generate data

# Settings for the n < p case (n = 200, p = 2000), uncorrelated predictors (rho = 0.0).
# NOTE(review): p_causal and pve are presumably the number of non-zero
# coefficients and the proportion of variance explained -- confirm against
# simulate.equicorr_predictors.
n = 200
p = 2000
p_causal = 20
pve = 0.7
k = 20  # number of prior components passed to initialize_ash_prior

X, y, Xtest, ytest, btrue, strue = simulate.equicorr_predictors(n, p, p_causal, pve, rho = 0.0, seed = 20)
# Training and test predictors are standardized separately, each with its own statistics.
X      = center_and_scale(X)
Xtest  = center_and_scale(Xtest)
# wk: initial mixture weights, sk: prior grid.
wk, sk = initialize_ash_prior(k)

## Optimize
# L-BFGS-B fit starting from binit = None, weights wk and residual variance 1.
plr_lbfgs = PLR(method = 'L-BFGS-B', optimize_w = True, optimize_s = True, is_prior_scaled = True,
debug = False, display_progress = False)
plr_lbfgs.fit(X, y, sk, binit = None, winit = wk, s2init = 1)

## Plot
plot_linear_mrashpen(X, y, Xtest, ytest, btrue, strue, plr_lbfgs)

### Linear model with n > p and correlated X

### Generate data

# Settings for the n > p case (n = 2000, p = 200) with rho = 0.5 passed to the simulator.
# NOTE(review): p_causal and pve are presumably the number of non-zero
# coefficients and the proportion of variance explained -- confirm against
# simulate.equicorr_predictors.
n = 2000
p = 200
p_causal = 20
pve = 0.7
k = 20  # number of prior components passed to initialize_ash_prior

X, y, Xtest, ytest, btrue, strue = simulate.equicorr_predictors(n, p, p_causal, pve, rho = 0.5, seed = 10)
# Training and test predictors are standardized separately, each with its own statistics.
X      = center_and_scale(X)
Xtest  = center_and_scale(Xtest)
# wk: initial mixture weights, sk: prior grid.
wk, sk = initialize_ash_prior(k)

## Optimize
# L-BFGS-B fit starting from binit = None, weights wk and residual variance 1.
plr_lbfgs = PLR(method = 'L-BFGS-B', optimize_w = True, optimize_s = True, is_prior_scaled = True,
debug = False, display_progress = False)
plr_lbfgs.fit(X, y, sk, binit = None, winit = wk, s2init = 1)

## Plot
plot_linear_mrashpen(X, y, Xtest, ytest, btrue, strue, plr_lbfgs)