Optimization of the penalty function using Lagrange multipliers
About
An interesting sub-problem in the penalized regression formulation of Mr.ASH sparse multiple linear regression is the optimization of the penalty function $\rho$. The details of the theory are written up in Overleaf. The numerical optimization of $\rho$ can be performed using:
- the method of Lagrange multipliers
- partial updates with ADMM
Here, we show a basic implementation of the method of Lagrange multipliers.
import numpy as np
from scipy import optimize as sp_optimize
import matplotlib.pyplot as plt
from pymir import mpl_stylesheet
from pymir import mpl_utils
mpl_stylesheet.banskt_presentation(splinecolor = 'black')
from mrashpen.models.normal_means_ash_scaled import NormalMeansASHScaled
def sample_mixgauss(wk, sk, size):
runif = np.random.uniform(0, 1, size = size)
gcomp = np.digitize(runif, np.cumsum(wk))
x = np.zeros(size)
for i, gc in enumerate(gcomp):
if sk[gc] > 0: x[i] = np.random.normal(0, sk[gc])
return x
def NM_sample(mean, std = 1.0):
p = mean.shape[0]
cov = np.eye(p) * std * std
y = np.random.multivariate_normal(mean, cov)
return y
def initialize_ash_prior(k, scale = 2, sparsity = None):
w = np.zeros(k)
w[0] = 1 / k if sparsity is None else sparsity
w[1:(k-1)] = np.repeat((1 - w[0])/(k-1), (k - 2))
w[k-1] = 1 - np.sum(w)
sk2 = np.square((np.power(scale, np.arange(k) / k) - 1))
prior_grid = np.sqrt(sk2)
return w, prior_grid
We generate 500 samples from a normal means (NM) distribution whose means are drawn from the Mr.ASH prior, with a given variance ($= \sigma^2 v_j^2$). In the multiple regression setting,
$\mathbf{y} \mid \mathbf{X}, \mathbf{b}, \sigma \sim \mathcal{N}(\mathbf{y} \mid \mathbf{X}\mathbf{b}, \sigma^2\mathbb{I})$,
the NM samples play the role of the regression coefficients $\mathbf{b}$. Given $\sigma_k$ (the standard deviations of the mixture components in the Mr.ASH prior), $\sigma$, and $v_j^2 = 1 / d_j$, where $d_j$ is the $j$-th diagonal element of $\mathbf{X}^{\mathsf{T}}\mathbf{X}$ (the variable `dj` below, set to 1 in this example), the target is to recover the mixture weights $w_k$ of the Mr.ASH prior and the corresponding $\boldsymbol{\theta}$ such that $M_{w, \sigma}(\boldsymbol{\theta}) = \mathbf{b}$, where $M_{w, \sigma}$ is the NM posterior means operator.
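Throughout, $M_{w, \sigma}$ is evaluated via Tweedie's formula, $M_{w, \sigma}(\theta_j) = \theta_j + \sigma^2 v_j^2 \, \partial_{\theta_j} \log \ell_{w, \sigma}(\theta_j)$, where $\ell_{w, \sigma}(\theta_j)$ denotes the NM marginal likelihood; this restates what the `shrinkage_operator` function below computes from `NormalMeansASHScaled`, and is included only to connect the notation with the code.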
We consider the minimization of the penalty function,
$\min_{w, \theta} \sum_{j=1}^{P} \frac{1}{\sigma^2 v_j^2} \rho_{w, \sigma} (M_{w, \sigma}(\theta_j))$ subject to $w_k \ge 0, \sum_{k} w_k = 1$ and $M_{w, \sigma}(\theta_j) = b_j$
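For reference, the implementation below (`penalty_operator`) evaluates the $j$-th term of this objective directly in terms of the NM log marginal likelihood as $h_j(\theta_j) = -\log \ell_{w, \sigma}(\theta_j) - \tfrac{1}{2}\, \sigma^2 v_j^2 \left( \partial_{\theta_j} \log \ell_{w, \sigma}(\theta_j) \right)^2$; the derivation connecting this expression to $\rho_{w, \sigma}$ is left to the Overleaf notes and is not repeated here.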
Note: Ideally, $\sigma$ should also be obtained from the optimization, but it is assumed to be known for now. In this document, I have not yet implemented the derivatives of the objective function, although they can be obtained as derived in the Overleaf notes.
p = 500
k = 3
sparsity = 0.8
strue = 1.0
np.random.seed(100)
wtrue, sk = initialize_ash_prior(k, sparsity = sparsity)
btrue = sample_mixgauss(wtrue, sk, p)
y = NM_sample(btrue, std = strue)
dj = np.ones(p)
We use a softmax parametrization of $w_k$ to enforce the first two constraints ($w_k \ge 0$ and $\sum_k w_k = 1$), and the method of Lagrange multipliers for the equality constraint $M_{w, \sigma}(\theta_j) = b_j$, as written out below.
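Concretely, we write $w_k = e^{a_k} / \sum_{k'} e^{a_{k'}}$, which guarantees $w_k \ge 0$ and $\sum_k w_k = 1$, and form the Lagrangian
$\mathcal{L}(\boldsymbol{\theta}, \mathbf{a}, \boldsymbol{\lambda}) = \sum_{j=1}^{P} h_j(\theta_j) + \sum_{j=1}^{P} \lambda_j \left( M_{w, \sigma}(\theta_j) - b_j \right),$
with $h_j$ as defined above. Rather than solving the stationarity conditions in closed form, the implementation below (see `penalty_operator_lagrangian` and `objective_numeric_lagrangian`) searches for a stationary point by minimizing the Euclidean norm of the gradient of $\mathcal{L}$ with respect to $(\boldsymbol{\theta}, \boldsymbol{\lambda}, \mathbf{a})$, which vanishes at a stationary point of $\mathcal{L}$.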
sk_str = ", ".join([f"{x:.3f}" for x in sk])
print(f"Standard deviation of the mixture components:\n[{sk_str}]")
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.hist(btrue, density = True)
ax1.set_xlabel("b")
ax1.set_ylabel("Density")
plt.show()
def softmax(x, base = np.exp(1)):
if base is not None:
beta = np.log(base)
x = x * beta
e_x = np.exp(x - np.max(x))
return e_x / np.sum(e_x, axis = 0, keepdims = True)
def penalty_operator(z, wk, std, sk, dj):
nm = NormalMeansASHScaled(z, std, wk, sk, d = dj)
tvar = (std * std) / dj
lambdaj = - nm.logML - 0.5 * tvar * np.square(nm.logML_deriv)
return lambdaj
def shrinkage_operator(nm):
M = nm.y + nm.yvar * nm.logML_deriv
M_bgrad = 1 + nm.yvar * nm.logML_deriv2
M_wgrad = nm.yvar.reshape(-1, 1) * nm.logML_deriv_wderiv
M_s2grad = (nm.logML_deriv / nm._d) + (nm.yvar * nm.logML_deriv_s2deriv)
return M, M_bgrad, M_wgrad, M_s2grad
def unshrink_b(b, std, wk, sk, dj, theta = None, max_iter = 100, tol = 1e-8):
# this is the initial value of theta
if theta is None:
theta = np.zeros_like(b)
# Newton-Raphson iteration
for itr in range(max_iter):
nmash = NormalMeansASHScaled(theta, std, wk, sk, d = dj)
Mtheta, Mtheta_bgrad, _, _ = shrinkage_operator(nmash)
theta_new = theta - (Mtheta - b) / Mtheta_bgrad
diff = np.sum(np.square(theta_new - theta))
theta = theta_new
obj = np.sum(- nmash.logML - 0.5 * nmash.yvar * np.square(nmash.logML_deriv))
print(obj)
if diff <= tol:
break
return theta
def shrink_theta(z, std, wk, sk, dj):
nmash = NormalMeansASHScaled(z, std, wk, sk, d = dj)
Mb = shrinkage_operator(nmash)[0]
return Mb
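As a quick sanity check (a minimal sketch using the objects defined above; `theta_chk` and `b_chk` are illustrative names, not part of the original analysis), `unshrink_b` and `shrink_theta` should behave as mutual inverses:
# Round-trip check: applying the shrinkage operator to the Newton-Raphson
# inverse of b should recover b (unshrink_b prints its objective at every iteration)
theta_chk = unshrink_b(btrue, strue, wtrue, sk, dj)
b_chk = shrink_theta(theta_chk, strue, wtrue, sk, dj)
print("Max inversion error:", np.max(np.abs(b_chk - btrue)))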
def penalty_operator_lagrangian(z, wk, std, sk, dj, lgrng, b):
Mt = shrink_theta(z, std, wk, sk, dj)
hwt = penalty_operator(z, wk, std, sk, dj)
obj = np.sum(hwt) + np.sum(lgrng * (Mt - b))
return obj
def penalty_operator_lagrangian_deriv(z, wk, std, sk, dj, lgrng, b):
    # The Normal Means model
    nmash = NormalMeansASHScaled(z, std, wk, sk, d = dj)
    # Gradient w.r.t. lambda_j (the Lagrange multipliers)
    M, M_bgrad, M_wgrad, M_s2grad = shrinkage_operator(nmash)
    dLdl = M - b
    # Gradient w.r.t. wk (prior mixture coefficients)
    tvar = (std * std) / dj
    v2_ld_ldwd = tvar.reshape(-1, 1) * nmash.logML_deriv.reshape(-1, 1) * nmash.logML_deriv_wderiv
    ## gradient of first term and second term of the Lagrangian
    l1_wgrad = - nmash.logML_wderiv - v2_ld_ldwd
    l2_wgrad = lgrng.reshape(-1, 1) * M_wgrad
    dLdw = np.sum(l1_wgrad + l2_wgrad, axis = 0)
    # Gradient w.r.t. theta
    l1_tgrad = - nmash.logML_deriv - tvar * nmash.logML_deriv * nmash.logML_deriv2
    l2_tgrad = lgrng * (1 + tvar * nmash.logML_deriv2)
    dLdt = l1_tgrad + l2_tgrad
    return dLdl, dLdw, dLdt
def objective_numeric_lagrangian(params, std, sk, dj, b, p, k, softmax_base):
    # Unpack theta, the Lagrange multipliers and the softmax parameters of wk
    zj = params[:p]
    lj = params[p:2*p]
    ak = params[2*p:]
    wk = softmax(ak, base = softmax_base)
    dLdl, dLdw, dLdt = penalty_operator_lagrangian_deriv(zj, wk, std, sk, dj, lj, b)
    # Chain rule through the (symmetric) softmax Jacobian to get the gradient w.r.t. ak
    akjac = np.log(softmax_base) * wk.reshape(-1, 1) * (np.eye(k) - wk)
    dLda = np.sum(dLdw * akjac, axis = 1)
    # Objective: Euclidean norm of the full gradient of the Lagrangian
    obj = np.sqrt(np.sum(np.square(dLdl)) + np.sum(np.square(dLda)) + np.sum(np.square(dLdt)))
    return obj
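As a consistency check (a minimal sketch using the objects defined above; `eps`, `z_test` and `lam_test` are illustrative names introduced here), the analytical gradient with respect to $\theta$ can be compared against a central finite difference of the Lagrangian:
# Finite-difference check of dL/d(theta_0); the two numbers should agree up to
# numerical error if the analytical derivatives from NormalMeansASHScaled are consistent
eps = 1e-6
z_test = np.random.normal(0, 1, p)
lam_test = np.ones(p) * 2.0
_, _, dLdt = penalty_operator_lagrangian_deriv(z_test, wtrue, strue, sk, dj, lam_test, btrue)
zp, zm = z_test.copy(), z_test.copy()
zp[0] += eps
zm[0] -= eps
fd = (penalty_operator_lagrangian(zp, wtrue, strue, sk, dj, lam_test, btrue)
      - penalty_operator_lagrangian(zm, wtrue, strue, sk, dj, lam_test, btrue)) / (2 * eps)
print("Analytical dL/dtheta_0:", dLdt[0], " finite difference:", fd)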
We initialize $\boldsymbol{\theta} = \mathbf{b}$ (the true coefficients), set $w_k = 1 / K$ for all $k$, and set the Lagrange multipliers $\lambda_j = 2$ for all $j$.
z = btrue.copy()
softmax_base = np.exp(1)
winit, _ = initialize_ash_prior(k)
akinit = np.log(winit + 1e-8) / np.log(softmax_base)
lgrng = np.ones(btrue.shape[0]) * 2.0
wk = softmax(akinit, base = softmax_base)
The optimization is carried out using SciPy's conjugate gradient (CG) minimization routine with numerical derivatives of the objective.
initparams = np.concatenate([z, lgrng, akinit])
cg_min = sp_optimize.minimize(objective_numeric_lagrangian, initparams,
args = (strue, sk, dj, btrue, p, k, softmax_base),
method = 'CG',
options = {'disp': True, 'maxiter': 50, 'return_all': True}
)
We compare the true mixture coefficients $w_k$ with the estimates $\hat{w}_k$, and obtain the estimated coefficients $\hat{b}_j = M_{\hat{w}, \sigma}(\hat{\theta}_j)$ from the estimated $\hat{\theta}_j$.
z_cg = cg_min.x[:p]
a_cg = cg_min.x[2*p:]
w_cg = softmax(a_cg, base = softmax_base)
lj_cg = cg_min.x[p:2*p]
print ("True w:", wtrue)
print("Est. w: ", w_cg)
print("Est. Lagrangian (for non-zero b): ")
print(lj_cg[btrue!=0])
b_cg = shrink_theta(z_cg, strue, w_cg, sk, dj)
fig = plt.figure(figsize = (12,6))
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)
ax1.scatter(btrue, b_cg)
mpl_utils.plot_diag(ax1)
ax1.set_xlabel("True b")
ax1.set_ylabel("Estimated b")
niter = cg_min.nit
allobjs = np.zeros(niter + 1)
for i, params in enumerate(cg_min.allvecs):
z_it = params[:p]
a_it = params[2*p:]
lj_it = params[p:2*p]
w_it = softmax(a_it, base = softmax_base)
#alllagrangian[i] = penalty_operator_lagrangian(z_it, w_it, strue, sk, dj, lj_it, btrue)
allobjs[i] = objective_numeric_lagrangian(params, strue, sk, dj, btrue, p, k, softmax_base)
ax2.plot(np.arange(niter + 1), allobjs)
ax2.set_xlabel("# iterations")
ax2.set_ylabel(r"h(w, $\theta$)")
plt.tight_layout()
plt.show()
Here, I show the evolution of $w_k$ and estimated $\mathbf{b}$ over the iterations.
subplot_h = 1.8
nstep = 5
nplot = int(niter / nstep) + 1
ncol = 4
nrow = int(nplot / ncol + 1) if nplot%ncol != 0 else int(nplot / ncol)
figw = ncol * subplot_h + (ncol - 1) * 0.3 + 1.2
figh = nrow * subplot_h + (nrow - 1) * 0.3 + 1.5
figscale = 12.0 / figw
bgcolor = '#F0F0F0'
highlight_color = '#EE6868'
subdue_color = '#848f94'
text_color = '#69767c'
fig = plt.figure(figsize = (figw * figscale, figh * figscale))
axmain = fig.add_subplot(111)
for i in range(nplot):
ax = fig.add_subplot(nrow, ncol, i + 1)
itr = i * nstep
params = cg_min.allvecs[itr]
z_it = params[:p]
a_it = params[2*p:]
#lj_it = params[p:2*p]
w_it = softmax(a_it, base = softmax_base)
b_it = shrink_theta(z_it, strue, w_it, sk, dj)
ax.scatter(btrue, b_it, s=10)
mpl_utils.plot_diag(ax)
wtext = r'$w_k$ = ' + ', '.join([f"{w:.2f}" for w in w_it])
itrtext = f"Iteration {itr}"
ax.text(0.05, 0.85, wtext, va='top', ha='left',
transform=ax.transAxes, color = text_color, fontsize = 10)
ax.text(0.05, 0.95, itrtext, va='top', ha='left',
transform=ax.transAxes, color = text_color, fontsize = 10)
ax.tick_params(bottom = False, top = False, left = False, right = False,
labelbottom = False, labeltop = False, labelleft = False, labelright = False)
ax.set_facecolor(bgcolor)
for side, border in ax.spines.items():
border.set_visible(False)
if i < ncol:
ax.tick_params(top = True, labeltop = True, color = bgcolor, width = 5)
#ax.set_xticks(np.log10([0.001, 0.01, 0.1, 1.0]))
if i%ncol == 0:
ax.tick_params(left = True, labelleft = True, color = bgcolor, width = 5)
#ax.set_ylim(-0.1, 2.1)
axmain.tick_params(bottom = False, top = False, left = False, right = False,
labelbottom = False, labeltop = False, labelleft = False, labelright = False)
for side, border in axmain.spines.items():
border.set_visible(False)
axmain.set_ylabel(r'Estimated b', labelpad = 40, color = text_color)
axmain.set_xlabel(r'True b', labelpad = 50, color = text_color)
axmain.xaxis.set_label_position('top')
plt.tight_layout()
# plt.savefig(f'../plots/{fileprefix}.pdf', bbox_inches='tight')
# plt.savefig(f'../plots/{fileprefix}.png', bbox_inches='tight')
plt.show()
Why do the estimated $w_k$ differ from the true values?
We found that the estimated $\hat{w}_k$ differ from the true $w_k$. As shown below, this is because the objective function has a local minimum at this value. This minimum can be avoided by initializing $w_k$ such that $w_1$ is above a certain threshold. However, at the time of writing, I do not know the origin of this minimum.
In the following, I calculate the objective function $h(\theta, a)$ with $\boldsymbol{\theta}$ held fixed and $w_3 = 0.001$, while varying $w_1$ and $w_2$ such that $w_1 + w_2 + w_3 = 1$. The plot below shows that the local minimum appears irrespective of the four different choices of $\boldsymbol{\theta}$ used for the plot (see figure legend). The "true" value of $\boldsymbol{\theta}$ is obtained as $M_{w, \sigma}^{-1}(\mathbf{b})$ using Newton-Raphson inversion.
z_newraph = unshrink_b(btrue, strue, wtrue, sk, dj)
def get_obj_list(aseq, z, std, sk, dj, b, p, k, softmax_base):
    h_seq = list()
    for ak in aseq:
        # lgrng (the Lagrange multipliers) is taken from the enclosing scope
        params = np.concatenate([z, lgrng, ak])
        h = objective_numeric_lagrangian(params, std, sk, dj, b, p, k, softmax_base)
        h_seq.append(h)
    return np.array(h_seq)
def get_wk_text(wk):
wstr = ", ".join([f"{w:.2f}" for w in wk[:2]])
return f"({wstr})"
nseq = 101
w3 = 0.001
w1seq = np.linspace(0, 1 - w3, nseq)
w2seq = 1 - w3 - w1seq
w_seq = [np.array([w1, w2, w3]) for w1, w2 in zip(w1seq, w2seq)]
a_seq = [np.log(w + 1e-8) / np.log(softmax_base) for w in w_seq]
w_seq = [softmax(ak, base = softmax_base) for ak in a_seq]
h_seq_random = get_obj_list(a_seq, np.random.rand(p), strue, sk, dj, btrue, p, k, softmax_base)
h_seq_zeros = get_obj_list(a_seq, np.zeros(p), strue, sk, dj, btrue, p, k, softmax_base)
h_seq_btrue = get_obj_list(a_seq, btrue, strue, sk, dj, btrue, p, k, softmax_base)
h_seq_ttrue = get_obj_list(a_seq, z_newraph, strue, sk, dj, btrue, p, k, softmax_base)
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.plot(np.arange(nseq), h_seq_random, label = r"Random")
ax1.plot(np.arange(nseq), h_seq_zeros, label = r"Zeros")
ax1.plot(np.arange(nseq), h_seq_btrue, label = r"$\theta = b$")
ax1.plot(np.arange(nseq), h_seq_ttrue, label = r"True ($\theta = M^{-1}(b)$)")
legendtitle = r"Choice of $\theta$"
legend1 = ax1.legend(loc = 'upper left', bbox_to_anchor = (0.02, 0.98), frameon = False, title = legendtitle)
legend1._legend_box.align = "left"
#lframe = legend1.get_frame()
#lframe.set_linewidth(0)
#ax1.hist(mean)
#ax1.legend()
ax1.set_ylabel(r"$h(\theta, a)$")
ax1.set_xlabel(r"($w_1, w_2$) with $w_3 =$ " + f"{w3:.4f}")
xtickpos = np.arange(0, nseq, 10)
xticklabels = [get_wk_text(w_seq[i]) for i in xtickpos]
ax1.set_xticks(xtickpos)
ax1.set_xticklabels(xticklabels, rotation = 90)
plt.show()
#print ("Using Newton-Raphson inversion to obtain initial 𝜃 from true w and b")
#z = unshrink_b(btrue, strue, wtrue, sk, dj)
z = btrue.copy()
softmax_base = np.exp(1)
winit, _ = initialize_ash_prior(k, sparsity = 0.7)
akinit = np.log(winit + 1e-8) / np.log(softmax_base)
lgrng = np.ones(btrue.shape[0]) * 1.0
wk = softmax(akinit, base = softmax_base)
initparams = np.concatenate([z, lgrng, akinit])
cg_min = sp_optimize.minimize(objective_numeric_lagrangian, initparams,
args = (strue, sk, dj, btrue, p, k, softmax_base),
method = 'CG',
options = {'disp': True, 'maxiter': 500, 'return_all': True}
)
z_cg = cg_min.x[:p]
a_cg = cg_min.x[2*p:]
w_cg = softmax(a_cg, base = softmax_base)
lj_cg = cg_min.x[p:2*p]
print ("True w:", wtrue)
print("Est. w: ", w_cg)
b_cg = shrink_theta(z_cg, strue, w_cg, sk, dj)
fig = plt.figure(figsize = (12,6))
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)
ax1.scatter(btrue, b_cg)
mpl_utils.plot_diag(ax1)
ax1.set_xlabel("True b")
ax1.set_ylabel("Estimated b")
niter = cg_min.nit
allobjs = np.zeros(niter + 1)
for i, params in enumerate(cg_min.allvecs):
z_it = params[:p]
a_it = params[2*p:]
lj_it = params[p:2*p]
w_it = softmax(a_it, base = softmax_base)
#alllagrangian[i] = penalty_operator_lagrangian(z_it, w_it, strue, sk, dj, lj_it, btrue)
allobjs[i] = objective_numeric_lagrangian(params, strue, sk, dj, btrue, p, k, softmax_base)
ax2.plot(np.arange(niter + 1), allobjs)
ax2.set_xlabel("# iterations")
ax2.set_ylabel(r"h(w, $\theta$)")
plt.tight_layout()
plt.show()
# ========================
# Plot progress of optimization
# ========================
subplot_h = 1.8
niter = min(200, cg_min.nit)  # cap the number of iterations shown in the progress plots
nstep = 20
nplot = int(niter / nstep) + 1
ncol = 4
nrow = int(nplot / ncol + 1) if nplot%ncol != 0 else int(nplot / ncol)
figw = ncol * subplot_h + (ncol - 1) * 0.3 + 1.2
figh = nrow * subplot_h + (nrow - 1) * 0.3 + 1.5
figscale = 12.0 / figw
bgcolor = '#F0F0F0'
highlight_color = '#EE6868'
subdue_color = '#848f94'
text_color = '#69767c'
fig = plt.figure(figsize = (figw * figscale, figh * figscale))
axmain = fig.add_subplot(111)
for i in range(nplot):
ax = fig.add_subplot(nrow, ncol, i + 1)
itr = i * nstep
params = cg_min.allvecs[itr]
z_it = params[:p]
a_it = params[2*p:]
#lj_it = params[p:2*p]
w_it = softmax(a_it, base = softmax_base)
b_it = shrink_theta(z_it, strue, w_it, sk, dj)
ax.scatter(btrue, b_it, s=10)
mpl_utils.plot_diag(ax)
wtext = r'$w_k$ = ' + ', '.join([f"{w:.2f}" for w in w_it])
itrtext = f"Iteration {itr}"
ax.text(0.05, 0.85, wtext, va='top', ha='left',
transform=ax.transAxes, color = text_color, fontsize = 10)
ax.text(0.05, 0.95, itrtext, va='top', ha='left',
transform=ax.transAxes, color = text_color, fontsize = 10)
ax.tick_params(bottom = False, top = False, left = False, right = False,
labelbottom = False, labeltop = False, labelleft = False, labelright = False)
ax.set_facecolor(bgcolor)
for side, border in ax.spines.items():
border.set_visible(False)
if i < ncol:
ax.tick_params(top = True, labeltop = True, color = bgcolor, width = 5)
#ax.set_xticks(np.log10([0.001, 0.01, 0.1, 1.0]))
if i%ncol == 0:
ax.tick_params(left = True, labelleft = True, color = bgcolor, width = 5)
#ax.set_ylim(-0.1, 2.1)
axmain.tick_params(bottom = False, top = False, left = False, right = False,
labelbottom = False, labeltop = False, labelleft = False, labelright = False)
for side, border in axmain.spines.items():
border.set_visible(False)
axmain.set_ylabel(r'Estimated b', labelpad = 40, color = text_color)
axmain.set_xlabel(r'True b', labelpad = 50, color = text_color)
axmain.xaxis.set_label_position('top')
plt.tight_layout()
# plt.savefig(f'../plots/{fileprefix}.pdf', bbox_inches='tight')
# plt.savefig(f'../plots/{fileprefix}.png', bbox_inches='tight')
plt.show()