Mr.ASH penalized regression trendfiltering demo
- About
- Fit data with Mr.Ash.Pen
- Failures
- Trendfiltering with less sparsity (p_causal = 20)
- Failures
- Example of higher order trendfiltering
# Numerical library plus the Mr.ASH penalized-regression classes used in this demo.
import numpy as np
from mrashpen.inference.penalized_regression import PenalizedRegression as MrASHPen
from mrashpen.models.normal_means_ash import NormalMeansASH
# NOTE(review): the path appended below is a machine-specific absolute path;
# presumably it is where the `simulate` module imported next lives -- adjust
# it before running this demo on another machine.
import sys
sys.path.append('/home/saikat/Documents/work/sparse-regression/simulation/eb-linreg-dsc/dsc/functions')
import simulate
# Plotting: matplotlib plus the pymir styling helpers.
import matplotlib.pyplot as plt
from pymir import mpl_stylesheet
from pymir import mpl_utils
# Configure plotting through pymir's `banskt_presentation` stylesheet helper,
# passing splinecolor = 'black'.
mpl_stylesheet.banskt_presentation(splinecolor = 'black')
Simulate trendfiltering data with basis k = 0
# Problem dimensions and simulation settings.
n, p = 100, 100
p_causal, snr = 4, 20
# Number of mixture components used for the prior below. This is NOT the
# `k` keyword passed to the simulator on the next statement.
k = 5

# Simulate changepoint data with the simulator's `k = 0` setting.
X, y, Xtest, ytest, beta, se = simulate.changepoint_predictors(
    n, p, p_causal, snr,
    k=0, signal='gamma', seed=20,
)
# Square of the `se` value returned by the simulator.
residual_var = se * se
First, I use fixed priors $w_k$ and fixed residual variance.
def plot_trendfilter_mrashpen(X, y, beta, ytest, res):
    """Draw a two-panel summary of a fitted model and show it.

    Left panel: ``ytest`` as open circles, with the prediction
    ``np.dot(X, res.coef)`` drawn as a line.
    Right panel: ``beta`` as open circles (labelled "btrue"), with
    ``res.coef`` overlaid as small filled dots.

    Parameters
    ----------
    X : array
        Design matrix; used for the prediction and for ``X.shape[1]``.
    y : array
        Only ``y.shape[0]`` is read; it sets the x-axis of the left panel.
    beta : array
        Coefficients plotted as "btrue" in the right panel.
    ytest : array
        Responses scattered in the left panel.
    res : object
        Fitted model exposing a ``coef`` attribute.
    """
    coef_hat = res.coef
    sample_idx = np.arange(y.shape[0])
    coef_idx = np.arange(X.shape[1])
    fitted = np.dot(X, coef_hat)

    fig = plt.figure(figsize=(12, 6))
    ax_fit = fig.add_subplot(1, 2, 1)
    ax_coef = fig.add_subplot(1, 2, 2)

    # Left panel: responses versus the fitted curve.
    ax_fit.scatter(sample_idx, ytest, edgecolor='black', facecolor='white', label="ytest")
    ax_fit.plot(sample_idx, fitted, label="Mr.Ash.Pen")
    ax_fit.legend()
    ax_fit.set_xlabel("Sample index")
    ax_fit.set_ylabel("y")

    # Right panel: reference coefficients versus the estimates.
    ax_coef.scatter(coef_idx, beta, edgecolor='black', facecolor='white', label="btrue")
    ax_coef.scatter(coef_idx, coef_hat, s=10, color='firebrick', label="Mr.Ash.Pen")
    ax_coef.legend()
    ax_coef.set_xlabel("Sample index")
    ax_coef.set_ylabel("b")

    plt.tight_layout()
    plt.show()
# Prior weights: zero on the first component, the remaining mass split
# equally over the other k - 1 components.
w = np.zeros(k)
w[0] = 0
w[1:k - 1] = (1 - w[0]) / (k - 1)  # scalar broadcast over the middle slots
w[-1] = 1 - w.sum()                # last slot takes whatever mass is left
prior_grid = 10 * np.arange(k)

print("Initial w: ", w)
print("Prior_grid: ", prior_grid)
print()

# Fit from an all-zero coefficient vector with optimize_w switched off.
b0 = np.zeros(p)
plr_lbfgs = MrASHPen(optimize_w=False, debug=False)
plr_lbfgs.fit(X, y, prior_grid, binit=b0, winit=w, s2init=residual_var)
plot_trendfilter_mrashpen(X, y, beta, ytest, plr_lbfgs)
Then, we try to optimize $w_k$ using the penalty function $h(b, w)$.
# The +1 shift keeps every grid value non-zero.
prior_grid = 10 * np.arange(k) + 1

print("Using w: ", w)
print("Using prior_grid: ", prior_grid)
print()

# Same starting point as before, now with optimize_w switched on.
b0 = np.zeros(p)
plr_lbfgs = MrASHPen(optimize_w=True, debug=False)
plr_lbfgs.fit(X, y, prior_grid, binit=b0, winit=w, s2init=residual_var)
plot_trendfilter_mrashpen(X, y, beta, ytest, plr_lbfgs)

print("Optimized w: ", plr_lbfgs.prior)
- If the prior grid has a zero component (see below), then a "division-by-zero" error occurs after a few updates. In the above example, I sidestepped the problem by using only non-zero grid values.
- To Do: check whether Definition 3.1 in the mixsqp paper is applicable to $h(b, w)$.
This has now been fixed.
# Grid whose first value is zero.
prior_grid = 10 * np.arange(k)

print("Using w: ", w)
print("Using prior_grid: ", prior_grid)
print()

b0 = np.zeros(p)

# Stage 1: fit with optimize_w switched off, starting from zeros.
plr_lbfgs = MrASHPen(optimize_w=False, debug=False)
plr_lbfgs.fit(X, y, prior_grid, binit=b0, winit=w, s2init=residual_var)

# Stage 2: fit with optimize_w switched on, started at the stage-1 coefficients.
plr_lbfgs2 = MrASHPen(optimize_w=True, debug=False)
plr_lbfgs2.fit(X, y, prior_grid, binit=plr_lbfgs.coef, winit=w, s2init=residual_var)

plot_trendfilter_mrashpen(X, y, beta, ytest, plr_lbfgs2)
print("Optimized w: ", plr_lbfgs2.prior)
# Same zero-containing grid, but fitted through `ebfit` instead of `fit`.
prior_grid = 10 * np.arange(k)

print("Using w: ", w)
print("Using prior_grid: ", prior_grid)
print()

b0 = np.zeros(p)
plr_lbfgs = MrASHPen(optimize_w=True, debug=False)
plr_lbfgs.ebfit(X, y, prior_grid, binit=b0, winit=w, s2init=residual_var)
plot_trendfilter_mrashpen(X, y, beta, ytest, plr_lbfgs)

print("Optimized w: ", plr_lbfgs.prior)
Use a relaxed prior grid
# Unit-spaced grid 0, 1, ..., k - 1 (no factor of 10).
prior_grid = np.arange(k)

print("Using w: ", w)
print("Using prior_grid: ", prior_grid)
print()

b0 = np.zeros(p)
plr_lbfgs = MrASHPen(optimize_w=True, debug=False)
plr_lbfgs.ebfit(X, y, prior_grid, binit=b0, winit=w, s2init=residual_var)
plot_trendfilter_mrashpen(X, y, beta, ytest, plr_lbfgs)

print("Optimized w: ", plr_lbfgs.prior)
What happens if I use higher order basis for fitting?
prior_grid = 10 * np.arange(k)

print("Using w: ", w)
print("Using prior_grid: ", prior_grid)
print()

b0 = np.zeros(p)

# Alternative design matrix from the simulator's trend-filtering basis.
# NOTE(review): the third argument (1) is presumably the basis order -- confirm
# against `simulate.trend_filtering_basis`.
X_alt = simulate.trend_filtering_basis(n, p, 1)

# Fit the same y against the alternative basis.
plr_lbfgs = MrASHPen(optimize_w=True, debug=False)
plr_lbfgs.ebfit(X_alt, y, prior_grid, binit=b0, winit=w, s2init=residual_var)
plot_trendfilter_mrashpen(X_alt, y, beta, ytest, plr_lbfgs)

print("Optimized w: ", plr_lbfgs.prior)
# Same setup as the first simulation, but with p_causal raised to 20.
n, p = 100, 100
p_causal, snr = 20, 20
k = 5  # number of prior components; distinct from the simulator's `k` keyword

X, y, Xtest, ytest, beta, se = simulate.changepoint_predictors(
    n, p, p_causal, snr,
    k=0, signal='gamma', seed=20,
)
residual_var = se * se

# Prior weights: zero on the first component, the remaining mass split
# equally over the other k - 1 components.
w = np.zeros(k)
w[0] = 0
w[1:k - 1] = (1 - w[0]) / (k - 1)
w[-1] = 1 - w.sum()
prior_grid = 10 * np.arange(k)

print("Initial w: ", w)
print("Prior_grid: ", prior_grid)
print()

b0 = np.zeros(p)
plr_lbfgs = MrASHPen(optimize_w=True, debug=False)
plr_lbfgs.fit(X, y, prior_grid, binit=b0, winit=w, s2init=residual_var)
plot_trendfilter_mrashpen(X, y, beta, ytest, plr_lbfgs)

print("Optimized w: ", plr_lbfgs.prior)
# Back to the sparse setting (p_causal = 4).
n, p = 100, 100
p_causal, snr = 4, 20
k = 5  # number of prior components; distinct from the simulator's `k` keyword

X, y, Xtest, ytest, beta, se = simulate.changepoint_predictors(
    n, p, p_causal, snr,
    k=0, signal='gamma', seed=20,
)
residual_var = se * se

# Prior weights: 0.01 on the first component, the remaining mass split
# equally over the other k - 1 components.
w = np.zeros(k)
w[0] = 0.01
w[1:k - 1] = (1 - w[0]) / (k - 1)
w[-1] = 1 - w.sum()
prior_grid = 10 * np.arange(k)

print("Initial w: ", w)
print("Prior_grid: ", prior_grid)
print()

b0 = np.zeros(p)
plr_lbfgs = MrASHPen(optimize_w=True, debug=False)
plr_lbfgs.fit(X, y, prior_grid, binit=b0, winit=w, s2init=residual_var)
plot_trendfilter_mrashpen(X, y, beta, ytest, plr_lbfgs)

print("Optimized w: ", plr_lbfgs.prior)
# Sparse setting (p_causal = 4) again, this time constructing the model with witer = 0.
n, p = 100, 100
p_causal, snr = 4, 20
k = 5  # number of prior components; distinct from the simulator's `k` keyword

X, y, Xtest, ytest, beta, se = simulate.changepoint_predictors(
    n, p, p_causal, snr,
    k=0, signal='gamma', seed=20,
)
residual_var = se * se

# Prior weights: 0.01 on the first component, the remaining mass split
# equally over the other k - 1 components.
w = np.zeros(k)
w[0] = 0.01
w[1:k - 1] = (1 - w[0]) / (k - 1)
w[-1] = 1 - w.sum()
prior_grid = 10 * np.arange(k)

print("Initial w: ", w)
print("Prior_grid: ", prior_grid)
print()

b0 = np.zeros(p)
# NOTE(review): `witer` is presumably an iteration setting for the w updates --
# confirm its meaning against PenalizedRegression.
plr_lbfgs = MrASHPen(optimize_w=True, debug=False, witer=0)
plr_lbfgs.fit(X, y, prior_grid, binit=b0, winit=w, s2init=residual_var)
plot_trendfilter_mrashpen(X, y, beta, ytest, plr_lbfgs)

print("Optimized w: ", plr_lbfgs.prior)
# Simulate with the simulator's `k = 1` setting and a different seed (10).
n, p = 100, 100
p_causal, snr = 4, 20
k = 5  # number of prior components; distinct from the simulator's `k` keyword

X, y, Xtest, ytest, beta, se = simulate.changepoint_predictors(
    n, p, p_causal, snr,
    k=1, signal='gamma', seed=10,
)
residual_var = se * se

# Prior weights: zero on the first component, the remaining mass split
# equally over the other k - 1 components.
w = np.zeros(k)
w[0] = 0
w[1:k - 1] = (1 - w[0]) / (k - 1)
w[-1] = 1 - w.sum()
# Unit-spaced grid 0, 1, ..., k - 1.
prior_grid = 1 * np.arange(k)

print("Initial w: ", w)
print("Prior_grid: ", prior_grid)
print()

b0 = np.zeros(p)
plr_lbfgs = MrASHPen(optimize_w=True, debug=False)
plr_lbfgs.fit(X, y, prior_grid, binit=b0, winit=w, s2init=residual_var)
plot_trendfilter_mrashpen(X, y, beta, ytest, plr_lbfgs)

print("Optimized w: ", plr_lbfgs.prior)