About

An interesting sub-problem in the penalized regression formulation of Mr.ASH sparse multiple linear regression is the optimization of the penalty function $\rho$. The details of the theory are written up in Overleaf. The numerical optimization of $\rho$ can be performed using:

  • method of Lagrange multipliers
  • partial updates with ADMM

Here, we show a basic implementation of the method of Lagrange multipliers.
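
As a reminder of the general recipe (a generic sketch, not specific to Mr.ASH): to minimize $f(x)$ subject to an equality constraint $g(x) = 0$, the method of Lagrange multipliers seeks stationary points of the Lagrangian

$\mathcal{L}(x, \lambda) = f(x) + \lambda^{\mathsf{T}} g(x)$

jointly over $x$ and the multipliers $\lambda$.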

import numpy as np
from scipy import optimize as sp_optimize

import matplotlib.pyplot as plt
from pymir import mpl_stylesheet
from pymir import mpl_utils
mpl_stylesheet.banskt_presentation(splinecolor = 'black')

from mrashpen.models.normal_means_ash_scaled import NormalMeansASHScaled

def sample_mixgauss(wk, sk, size):
    # Sample from a mixture of zero-mean Gaussians with weights wk and
    # standard deviations sk; sk = 0 corresponds to a point mass at zero.
    runif = np.random.uniform(0, 1, size = size)
    gcomp = np.digitize(runif, np.cumsum(wk))
    x = np.zeros(size)
    for i, gc in enumerate(gcomp):
        if sk[gc] > 0:
            x[i] = np.random.normal(0, sk[gc])
    return x

def NM_sample(mean, std = 1.0):
    # Draw one observation per coordinate from the normal means model,
    # y_j ~ N(mean_j, std^2).
    p   = mean.shape[0]
    cov = np.eye(p) * std * std
    y   = np.random.multivariate_normal(mean, cov)
    return y

def initialize_ash_prior(k, scale = 2, sparsity = None):
    # Mixture weights: w[0] is the weight of the null (zero) component,
    # set to `sparsity` if given; the remaining mass is split equally.
    w = np.zeros(k)
    w[0] = 1 / k if sparsity is None else sparsity
    w[1:(k-1)] = np.repeat((1 - w[0]) / (k - 1), (k - 2))
    w[k-1] = 1 - np.sum(w)
    # Standard deviations of the mixture components on a geometric-like grid.
    sk2 = np.square(np.power(scale, np.arange(k) / k) - 1)
    prior_grid = np.sqrt(sk2)
    return w, prior_grid

We generate 500 samples from a normal means (NM) distribution whose means are drawn from the Mr.ASH prior and whose variance ($= \sigma^2 v_j^2$) is given. In the multiple regression setting,

$\mathbf{y} \mid \mathbf{X}, \mathbf{b}, \sigma \sim \mathcal{N}(\mathbf{y} \mid \mathbf{X}\mathbf{b}, \sigma^2\mathbb{I})$,

the NM samples play the role of the regression coefficients $\mathbf{b}$. Given $\sigma_k$ (the standard deviations of the mixture components in the Mr.ASH prior), $\sigma$, and $v_j^2 = 1 / (\mathbf{X}^{\mathsf{T}}\mathbf{X})_{jj}$, the goal is to recover the mixture weights $w_k$ of the Mr.ASH prior and the corresponding $\boldsymbol{\theta}$ such that $M_{w, \sigma}(\boldsymbol{\theta}) = \mathbf{b}$, where $M_{w, \sigma}$ is the NM posterior mean operator (sketched below).
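
For concreteness, a sketch of the NM posterior mean operator in the form used by the code below (Tweedie's formula applied to the NM marginal likelihood; here $\ell_{\mathrm{NM}}$ denotes the NM log marginal likelihood, which `NormalMeansASHScaled` exposes as `logML`):

$M_{w, \sigma}(\theta_j) = \theta_j + \sigma^2 v_j^2 \, \dfrac{\partial}{\partial \theta_j} \ell_{\mathrm{NM}}\big(\theta_j; w, \sigma^2 v_j^2\big)$

This matches `shrinkage_operator` below, where `nm.yvar` plays the role of $\sigma^2 v_j^2$.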

We consider the minimization of the penalty function,

$\min_{w, \theta} \sum_{j=1}^{P} \frac{1}{\sigma^2 v_j^2} \rho_{w, \sigma} \big(M_{w, \sigma}(\theta_j)\big)$ subject to $w_k \ge 0$, $\sum_{k} w_k = 1$, and $M_{w, \sigma}(\theta_j) = b_j$.

Note: Ideally, $\sigma$ also needs to be obtained from the optimization, but it is assumed to be known for now. In this document, I have not yet implemented the analytical derivatives of the objective function, although they can be computed as derived in the Overleaf notebook.
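
For reference, a sketch of the Lagrangian for the equality constraints (the simplex constraints on $w$ are handled separately by a softmax parametrization, introduced below); up to the scaling convention for $\rho$, this is what `penalty_operator_lagrangian` below evaluates:

$\mathcal{L}(\boldsymbol{\theta}, w, \boldsymbol{\lambda}) = \sum_{j=1}^{P} \frac{1}{\sigma^2 v_j^2} \rho_{w, \sigma}\big(M_{w, \sigma}(\theta_j)\big) + \sum_{j=1}^{P} \lambda_j \big(M_{w, \sigma}(\theta_j) - b_j\big)$

Setting the derivative with respect to $\lambda_j$ to zero recovers the constraint $M_{w, \sigma}(\theta_j) = b_j$.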

p = 500
k = 3
sparsity = 0.8
strue = 1.0

np.random.seed(100)
wtrue, sk = initialize_ash_prior(k, sparsity = sparsity)
btrue = sample_mixgauss(wtrue, sk, p)
y = NM_sample(btrue, std = strue)
dj = np.ones(p)

We use a softmax parametrization of $w_k$ to enforce the first two constraints ($w_k \ge 0$ and $\sum_k w_k = 1$), and the method of Lagrange multipliers for the third constraint, $M_{w, \sigma}(\theta_j) = b_j$, as sketched below.
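
A sketch of this parametrization and its Jacobian, which is what `akjac` computes in `objective_numeric_lagrangian` below (with $\beta = \ln(\mathrm{base})$):

$w_k = \dfrac{e^{\beta a_k}}{\sum_{l} e^{\beta a_l}}, \qquad \dfrac{\partial w_k}{\partial a_l} = \beta \, w_k \, (\delta_{kl} - w_l)$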

sk_str = ", ".join([f"{x:.3f}" for x in sk])
print(f"Standard deviation of the mixture components:\n[{sk_str}]")
Standard deviation of the mixture components:
[0.000, 0.260, 0.587]
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.hist(btrue, density = True)
ax1.set_xlabel("b")
ax1.set_ylabel("Density")
plt.show()

def softmax(x, base = np.exp(1)):
    if base is not None:
        beta = np.log(base)
        x = x * beta
    e_x = np.exp(x - np.max(x))
    return e_x / np.sum(e_x, axis = 0, keepdims = True)

def penalty_operator(z, wk, std, sk, dj):
    # Penalty term at M(z), written in terms of the NM log marginal
    # likelihood (logML) and its derivative.
    nm = NormalMeansASHScaled(z, std, wk, sk, d = dj)
    tvar = (std * std) / dj
    lambdaj = - nm.logML - 0.5 * tvar * np.square(nm.logML_deriv)
    return lambdaj

def shrinkage_operator(nm):
    # Posterior mean operator M (Tweedie's formula) and its derivatives
    # with respect to the NM observation, the mixture weights and sigma^2.
    M        = nm.y + nm.yvar * nm.logML_deriv
    M_bgrad  = 1       + nm.yvar * nm.logML_deriv2
    M_wgrad  = nm.yvar.reshape(-1, 1) * nm.logML_deriv_wderiv
    M_s2grad = (nm.logML_deriv / nm._d) + (nm.yvar * nm.logML_deriv_s2deriv)
    return M, M_bgrad, M_wgrad, M_s2grad

def unshrink_b(b, std, wk, sk, dj, theta = None, max_iter = 100, tol = 1e-8):
    # this is the initial value of theta
    if theta is None:
        theta = np.zeros_like(b)
    # Newton-Raphson iteration
    for itr in range(max_iter):
        nmash = NormalMeansASHScaled(theta, std, wk, sk, d = dj)
        Mtheta, Mtheta_bgrad, _, _ = shrinkage_operator(nmash)
        theta_new = theta - (Mtheta - b) / Mtheta_bgrad
        diff = np.sum(np.square(theta_new - theta))
        theta = theta_new
        # monitor the penalty objective at each Newton-Raphson iteration
        obj = np.sum(- nmash.logML - 0.5 * nmash.yvar * np.square(nmash.logML_deriv))
        print(obj)
        if diff <= tol:
            break
    return theta

def shrink_theta(z, std, wk, sk, dj):
    nmash = NormalMeansASHScaled(z, std, wk, sk, d = dj)
    Mb = shrinkage_operator(nmash)[0]
    return Mb

def penalty_operator_lagrangian(z, wk, std, sk, dj, lgrng, b):
    Mt  = shrink_theta(z, std, wk, sk, dj)
    hwt = penalty_operator(z, wk, std, sk, dj)
    obj = np.sum(hwt) + np.sum(lgrng * (Mt - b))
    return obj

def penalty_operator_lagrangian_deriv(z, wk, std, sk, dj, lgrng, b):
    # The Normal Means model
    nmash = NormalMeansASHScaled(z, std, wk, sk, d = dj)
    # Gradient w.r.t. lambda_j (Lagrange multipliers): the constraint residual
    M, M_bgrad, M_wgrad, M_s2grad  = shrinkage_operator(nmash)
    dLdl = M - b
    # Gradient w.r.t. wk (prior mixture coefficients)
    tvar  = (std * std) / dj
    v2_ld_ldwd = tvar.reshape(-1, 1) * nmash.logML_deriv.reshape(-1, 1) * nmash.logML_deriv_wderiv
    # gradient of the first and second terms of the Lagrangian
    l1_wgrad = - nmash.logML_wderiv - v2_ld_ldwd
    l2_wgrad = lgrng.reshape(-1, 1) * M_wgrad
    dLdw = np.sum(l1_wgrad + l2_wgrad, axis = 0)
    # Gradient w.r.t. theta
    l1_tgrad = - nmash.logML_deriv  - tvar * nmash.logML_deriv * nmash.logML_deriv2
    l2_tgrad = lgrng * (1 + tvar * nmash.logML_deriv2)
    dLdt = l1_tgrad + l2_tgrad
    return dLdl, dLdw, dLdt

def objective_numeric_lagrangian(params, std, sk, dj, b, p, k, softmax_base):
    # Unpack: theta (p values), Lagrange multipliers (p values), softmax parameters a (k values)
    zj = params[:p]
    lj = params[p:2*p]
    ak = params[2*p:]
    wk = softmax(ak, base = softmax_base)
    dLdl, dLdw, dLdt = penalty_operator_lagrangian_deriv(zj, wk, std, sk, dj, lj, b)
    # Chain rule through the softmax: dw_k / da_l = log(base) * w_k * (delta_kl - w_l)
    akjac = np.log(softmax_base) * wk.reshape(-1, 1) * (np.eye(k) - wk)
    dLda = np.sum(dLdw * akjac, axis = 1)
    # Return the L2 norm of the gradient of the Lagrangian
    obj = np.sqrt(np.sum(np.square(dLdl)) + np.sum(np.square(dLda)) + np.sum(np.square(dLdt)))
    return obj

We initialize $\boldsymbol{\theta}$ at the observed coefficients $\mathbf{b}$, initialize $w_k = 1 / K$ for all $k$, and initialize the Lagrange multipliers $\lambda_j = 2$ for all $j$.

z = btrue.copy()
softmax_base = np.exp(1)
winit, _ = initialize_ash_prior(k)
akinit = np.log(winit + 1e-8) / np.log(softmax_base)
lgrng = np.ones(btrue.shape[0]) * 2.0
wk  = softmax(akinit, base = softmax_base)

The optimization is carried out with SciPy's `minimize` routine using the conjugate gradient (CG) method and numerical derivatives; the sketch below spells out the scalar objective being minimized.
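
For reference, a sketch of the scalar objective returned by `objective_numeric_lagrangian`: the Euclidean norm of the gradient of the Lagrangian, so a stationary point of $\mathcal{L}$ corresponds to a zero of $h$,

$h(\boldsymbol{\theta}, \boldsymbol{\lambda}, a) = \sqrt{ \sum_{j} \Big( \frac{\partial \mathcal{L}}{\partial \lambda_j} \Big)^2 + \sum_{j} \Big( \frac{\partial \mathcal{L}}{\partial \theta_j} \Big)^2 + \sum_{k} \Big( \frac{\partial \mathcal{L}}{\partial a_k} \Big)^2 }$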

initparams = np.concatenate([z, lgrng, akinit])
cg_min = sp_optimize.minimize(objective_numeric_lagrangian, initparams,
                              args = (strue, sk, dj, btrue, p, k, softmax_base),
                              method = 'CG',
                              options = {'disp': True, 'maxiter': 50, 'return_all': True}
                             )
Warning: Maximum number of iterations has been exceeded.
         Current function value: 0.064317
         Iterations: 50
         Function evaluations: 141564
         Gradient evaluations: 141

We compare the true mixture weights $w_k$ with the estimates $\hat{w}_k$, and compute the estimated coefficients $\hat{b}_j = M_{\hat{w}, \sigma}(\hat{\theta}_j)$ from the estimated $\hat{\theta}_j$.

z_cg = cg_min.x[:p]
a_cg = cg_min.x[2*p:]
w_cg = softmax(a_cg, base = softmax_base)
lj_cg = cg_min.x[p:2*p]
print ("True w:", wtrue)
print("Est. w: ", w_cg)
print("Est. Lagrangian (for non-zero b): ")
print(lj_cg[btrue!=0])
True w: [0.8 0.1 0.1]
Est. w:  [8.26656510e-10 9.99999999e-01 5.00086286e-23]
Est. Lagrangian (for non-zero b): 
[  4.65746338  -0.72228219   2.19581871 -20.74760139  -0.47981782
   4.08776542  -5.23040571   4.0732611    5.8192827   -2.85600382
   7.61223209  -0.64169434   1.55117538  -4.52189092   3.84529584
  -6.40911099   1.34263467   0.74601678   2.91163085   0.52401514
  -0.35633649  -1.64563195  17.79185439  -6.79795854   2.02029346
  10.27168356   1.71511957   7.41218778  -0.20495449   1.47799834
   9.2768989   -0.86696252   7.52486007  -1.60597406  20.01057258
  -0.62257956   2.22106016   1.53198102 -12.90503322   3.10295873
  -3.09145382  -3.95790707  -5.84327267   8.81561381  -9.78704609
   3.55815711   4.57717772 -11.92469756   6.09880133  -2.11818023
  -3.96198987  -6.20555407   0.21010724   2.150619     1.52215735
   0.12843162  -3.96418615   7.63602716   1.0666652   -6.64105081
  -2.8559933    6.61280159  -8.93686185  -0.51818036   8.10208724
   4.07699859   2.41526577   1.15972965  -3.98997814   3.57243994
   4.90284072   8.18978483   2.76462938  18.09971594  -4.87524623
  -0.26809257  -5.66512694  -4.84265603  -0.52373218   2.92534879
  11.59707464   3.56847503   0.41328906  -4.53537178  -3.66957321
  -4.61814849   8.46638799   0.2880943   -3.06002024  -0.18848804
  -4.93552146   2.31725925 -11.48551446   2.32043887   4.07078132
   5.47131337  -3.09140141 -11.75043757  -3.16841127   0.99243037
  -6.56996257   2.70313314  -4.64884226]

b_cg = shrink_theta(z_cg, strue, w_cg, sk, dj)
fig = plt.figure(figsize = (12,6))
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)

ax1.scatter(btrue, b_cg)
mpl_utils.plot_diag(ax1)
ax1.set_xlabel("True b")
ax1.set_ylabel("Estimated b")

niter   = cg_min.nit
allobjs = np.zeros(niter + 1)
for i, params in enumerate(cg_min.allvecs):
    z_it = params[:p]
    a_it = params[2*p:]
    lj_it = params[p:2*p]
    w_it = softmax(a_it, base = softmax_base)
    #alllagrangian[i] = penalty_operator_lagrangian(z_it, w_it, strue, sk, dj, lj_it, btrue)
    allobjs[i] = objective_numeric_lagrangian(params, strue, sk, dj, btrue, p, k, softmax_base)
ax2.plot(np.arange(niter + 1), allobjs)
ax2.set_xlabel("# iterations")
ax2.set_ylabel(r"h(w, $\theta$)")

plt.tight_layout()

plt.show()

Here, I show the evolution of $w_k$ and estimated $\mathbf{b}$ over the iterations.

subplot_h = 1.8
nstep = 5
nplot = int(niter / nstep) + 1
ncol  = 4
nrow  = int(nplot / ncol + 1) if nplot%ncol != 0 else int(nplot / ncol)
figw  = ncol * subplot_h + (ncol - 1) * 0.3 + 1.2
figh  = nrow * subplot_h + (nrow - 1) * 0.3 + 1.5
figscale = 12.0 / figw

bgcolor = '#F0F0F0'
highlight_color = '#EE6868'
subdue_color = '#848f94'
text_color = '#69767c'

fig = plt.figure(figsize = (figw * figscale, figh * figscale))
axmain = fig.add_subplot(111)

for i in range(nplot):
    ax  = fig.add_subplot(nrow, ncol, i + 1)
    itr = i * nstep
    
    params = cg_min.allvecs[itr]
    z_it = params[:p]
    a_it = params[2*p:]
    #lj_it = params[p:2*p]
    w_it = softmax(a_it, base = softmax_base)
    b_it = shrink_theta(z_it, strue, w_it, sk, dj)
    
    ax.scatter(btrue, b_it, s=10)
    mpl_utils.plot_diag(ax)
    
    wtext = r'$w_k$ = ' + ', '.join([f"{w:.2f}" for w in w_it])
    itrtext = f"Iteration {itr}"
    
    ax.text(0.05, 0.85, wtext, va='top', ha='left', 
            transform=ax.transAxes, color = text_color, fontsize = 10)
    ax.text(0.05, 0.95, itrtext, va='top', ha='left', 
            transform=ax.transAxes, color = text_color, fontsize = 10)
    ax.tick_params(bottom = False, top = False, left = False, right = False,
                   labelbottom = False, labeltop = False, labelleft = False, labelright = False)
    
    ax.set_facecolor(bgcolor)
    for side, border in ax.spines.items():
        border.set_visible(False)
    if i < ncol:
        ax.tick_params(top = True, labeltop = True, color = bgcolor, width = 5)
        #ax.set_xticks(np.log10([0.001, 0.01, 0.1, 1.0]))
    if i%ncol == 0:
        ax.tick_params(left = True, labelleft = True, color = bgcolor, width = 5)
    #ax.set_ylim(-0.1, 2.1)
    
axmain.tick_params(bottom = False, top = False, left = False, right = False,
                   labelbottom = False, labeltop = False, labelleft = False, labelright = False)
for side, border in axmain.spines.items():
    border.set_visible(False)
axmain.set_ylabel(r'Estimated b', labelpad = 40, color = text_color)
axmain.set_xlabel(r'True b', labelpad = 50, color = text_color)
axmain.xaxis.set_label_position('top') 


plt.tight_layout()
# plt.savefig(f'../plots/{fileprefix}.pdf', bbox_inches='tight')
# plt.savefig(f'../plots/{fileprefix}.png', bbox_inches='tight')
plt.show()

Why does $\hat{w}_k$ differ from the true values?

We found that the estimated $\hat{w}_k$ differ from the true $w_k$. As shown below, this is because the objective function has a local minimum at this value. This minimum can be avoided by initializing $w_k$ such that $w_1$ is above a certain threshold. However, at the time of writing, I do not know the origin of this minimum.

In the following, I calculate the objective function $h(\theta, a)$ with fixed $\boldsymbol{\theta}$ and fixed $w_3 = 0.001$, while varying $w_1$ and $w_2$ such that $w_1 + w_2 + w_3 = 1$. The plot below shows that the local minimum appears for all four choices of $\boldsymbol{\theta}$ used for the plot (see figure legend). The "true" value of $\boldsymbol{\theta}$ is obtained as $M_{w, \sigma}^{-1}(\mathbf{b})$ using Newton-Raphson inversion, sketched below.
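
A sketch of the Newton-Raphson inversion implemented in `unshrink_b` above: to solve $M_{w, \sigma}(\theta_j) = b_j$ for $\theta_j$, iterate

$\theta_j^{(t+1)} = \theta_j^{(t)} - \dfrac{M_{w, \sigma}\big(\theta_j^{(t)}\big) - b_j}{M_{w, \sigma}'\big(\theta_j^{(t)}\big)}$

until the update falls below the tolerance, where $M_{w, \sigma}'$ is the derivative of the posterior mean operator with respect to $\theta_j$ (the `Mtheta_bgrad` term in the code).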

z_newraph = unshrink_b(btrue, strue, wtrue, sk, dj)
468.03779262835235
2869.763321128963
496.9212094358775
891.5663335803786
515.5975643334306
698.5235410382788
547.5737188775904
561.7462094292475
556.0885850660974
555.6462276502863
555.6414410883494

def get_obj_list(aseq, z, std, sk, dj, b, p, k, softmax_base):
    # Evaluate the objective along a sequence of softmax parameters `aseq`
    # with theta fixed at `z`; `lgrng` is taken from the enclosing scope.
    h_seq = list()
    for ak in aseq:
        params = np.concatenate([z, lgrng, ak])
        h = objective_numeric_lagrangian(params, std, sk, dj, b, p, k, softmax_base)
        h_seq.append(h)
    return np.array(h_seq)

def get_wk_text(wk):
    wstr = ", ".join([f"{w:.2f}" for w in wk[:2]])
    return f"({wstr})"

nseq  = 101
w3 = 0.001
w1seq = np.linspace(0, 1 - w3, nseq)
w2seq = 1 - w3 - w1seq
w_seq = [np.array([w1, w2, w3]) for w1, w2 in zip(w1seq, w2seq)]
a_seq = [np.log(w + 1e-8) / np.log(softmax_base) for w in w_seq]
w_seq = [softmax(ak, base = softmax_base) for ak in a_seq]

h_seq_random = get_obj_list(a_seq, np.random.rand(p), strue, sk, dj, btrue, p, k, softmax_base)
h_seq_zeros  = get_obj_list(a_seq, np.zeros(p), strue, sk, dj, btrue, p, k, softmax_base)
h_seq_btrue  = get_obj_list(a_seq, btrue, strue, sk, dj, btrue, p, k, softmax_base)
h_seq_ttrue  = get_obj_list(a_seq, z_newraph, strue, sk, dj, btrue, p, k, softmax_base)
    
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.plot(np.arange(nseq), h_seq_random, label = r"Random")
ax1.plot(np.arange(nseq), h_seq_zeros,  label = r"Zeros")
ax1.plot(np.arange(nseq), h_seq_btrue,  label = r"$\theta = b$")
ax1.plot(np.arange(nseq), h_seq_ttrue,  label = r"True ($\theta = M^{-1}(b)$)")

legendtitle = r"Choice of $\theta$"
legend1 = ax1.legend(loc = 'upper left', bbox_to_anchor = (0.02, 0.98), frameon = False, title = legendtitle)
legend1._legend_box.align = "left"
#lframe = legend1.get_frame()
#lframe.set_linewidth(0)

#ax1.hist(mean)
#ax1.legend()
ax1.set_ylabel(r"$h(\theta, a)$")
ax1.set_xlabel(r"($w_1, w_2$) with $w_3 =$ " + f"{w3:.4f}")


xtickpos = np.arange(0, nseq, 10)
xticklabels = [get_wk_text(w_seq[i]) for i in xtickpos]
ax1.set_xticks(xtickpos)
ax1.set_xticklabels(xticklabels, rotation = 90)

plt.show()

Initialize $w_1$ above 0.5

Following the previous analysis, we expect the optimization to descend towards the true $w_k$ if we initialize $w_1 > 0.5$. We check this hypothesis by initializing $w = [0.7, 0.15, 0.15]$.

#print ("Using Newton-Raphson inversion to obtain initial 𝜃 from true w and b")
#z = unshrink_b(btrue, strue, wtrue, sk, dj)
z = btrue.copy()
softmax_base = np.exp(1)
winit, _ = initialize_ash_prior(k, sparsity = 0.7)
akinit = np.log(winit + 1e-8) / np.log(softmax_base)
lgrng = np.ones(btrue.shape[0]) * 1.0
wk  = softmax(akinit, base = softmax_base)

initparams = np.concatenate([z, lgrng, akinit])
cg_min = sp_optimize.minimize(objective_numeric_lagrangian, initparams,
                              args = (strue, sk, dj, btrue, p, k, softmax_base),
                              method = 'CG',
                              options = {'disp': True, 'maxiter': 500, 'return_all': True}
                             )

z_cg = cg_min.x[:p]
a_cg = cg_min.x[2*p:]
w_cg = softmax(a_cg, base = softmax_base)
lj_cg = cg_min.x[p:2*p]
print ("True w:", wtrue)
print("Est. w: ", w_cg)

b_cg = shrink_theta(z_cg, strue, w_cg, sk, dj)
fig = plt.figure(figsize = (12,6))
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)

ax1.scatter(btrue, b_cg)
mpl_utils.plot_diag(ax1)
ax1.set_xlabel("True b")
ax1.set_ylabel("Estimated b")

niter   = cg_min.nit
allobjs = np.zeros(niter + 1)
for i, params in enumerate(cg_min.allvecs):
    z_it = params[:p]
    a_it = params[2*p:]
    lj_it = params[p:2*p]
    w_it = softmax(a_it, base = softmax_base)
    #alllagrangian[i] = penalty_operator_lagrangian(z_it, w_it, strue, sk, dj, lj_it, btrue)
    allobjs[i] = objective_numeric_lagrangian(params, strue, sk, dj, btrue, p, k, softmax_base)
ax2.plot(np.arange(niter + 1), allobjs)
ax2.set_xlabel("# iterations")
ax2.set_ylabel(r"h(w, $\theta$)")

plt.tight_layout()
plt.show()
Warning: Maximum number of iterations has been exceeded.
         Current function value: 0.006265
         Iterations: 500
         Function evaluations: 1191748
         Gradient evaluations: 1187
True w: [0.8 0.1 0.1]
Est. w:  [0.62024743 0.00100011 0.37875247]

# ========================
# Plot progress of optimization
# ========================

subplot_h = 1.8
niter = 200
nstep = 20
nplot = int(niter / nstep) + 1
ncol  = 4
nrow  = int(nplot / ncol + 1) if nplot%ncol != 0 else int(nplot / ncol)
figw  = ncol * subplot_h + (ncol - 1) * 0.3 + 1.2
figh  = nrow * subplot_h + (nrow - 1) * 0.3 + 1.5
figscale = 12.0 / figw

bgcolor = '#F0F0F0'
highlight_color = '#EE6868'
subdue_color = '#848f94'
text_color = '#69767c'

fig = plt.figure(figsize = (figw * figscale, figh * figscale))
axmain = fig.add_subplot(111)

for i in range(nplot):
    ax  = fig.add_subplot(nrow, ncol, i + 1)
    itr = i * nstep
    
    params = cg_min.allvecs[itr]
    z_it = params[:p]
    a_it = params[2*p:]
    #lj_it = params[p:2*p]
    w_it = softmax(a_it, base = softmax_base)
    b_it = shrink_theta(z_it, strue, w_it, sk, dj)
    
    ax.scatter(btrue, b_it, s=10)
    mpl_utils.plot_diag(ax)
    
    wtext = r'$w_k$ = ' + ', '.join([f"{w:.2f}" for w in w_it])
    itrtext = f"Iteration {itr}"
    
    ax.text(0.05, 0.85, wtext, va='top', ha='left', 
            transform=ax.transAxes, color = text_color, fontsize = 10)
    ax.text(0.05, 0.95, itrtext, va='top', ha='left', 
            transform=ax.transAxes, color = text_color, fontsize = 10)
    ax.tick_params(bottom = False, top = False, left = False, right = False,
                   labelbottom = False, labeltop = False, labelleft = False, labelright = False)
    
    ax.set_facecolor(bgcolor)
    for side, border in ax.spines.items():
        border.set_visible(False)
    if i < ncol:
        ax.tick_params(top = True, labeltop = True, color = bgcolor, width = 5)
        #ax.set_xticks(np.log10([0.001, 0.01, 0.1, 1.0]))
    if i%ncol == 0:
        ax.tick_params(left = True, labelleft = True, color = bgcolor, width = 5)
    #ax.set_ylim(-0.1, 2.1)
    
axmain.tick_params(bottom = False, top = False, left = False, right = False,
                   labelbottom = False, labeltop = False, labelleft = False, labelright = False)
for side, border in axmain.spines.items():
    border.set_visible(False)
axmain.set_ylabel(r'Estimated b', labelpad = 40, color = text_color)
axmain.set_xlabel(r'True b', labelpad = 50, color = text_color)
axmain.xaxis.set_label_position('top') 


plt.tight_layout()
# plt.savefig(f'../plots/{fileprefix}.pdf', bbox_inches='tight')
# plt.savefig(f'../plots/{fileprefix}.png', bbox_inches='tight')
plt.show()