Effect of $\sigma_b$ and $\sigma_w$ in EBMR with product of normals
- About
- Toy example
- Fix initial values of $\sigma^2$, $\sigma_b^2$ and $\sigma_w^2$
- Dependence on $\sigma_b$ and $\sigma_w$
import numpy as np
import pandas as pd
from scipy import linalg as sc_linalg
import matplotlib.pyplot as plt
import sys
sys.path.append("../../ebmrPy/")
from inference.ebmr import EBMR
from inference import f_elbo
from inference import f_sigma
from inference import penalized_em
from utils import log_density
sys.path.append("../../utils/")
import mpl_stylesheet
from matplotlib import cm
from matplotlib import ticker as plticker
from mpl_toolkits.axes_grid1 import make_axes_locatable
mpl_stylesheet.banskt_presentation(fontfamily = 'latex-clearsans', fontsize = 18, colors = 'banskt', dpi = 72)
def standardize(X):
    # Center each column, then scale to unit L2 norm
    # (not unit standard deviation; that variant is kept below for reference)
    Xnorm = (X - np.mean(X, axis = 0))
    #Xstd = Xnorm / np.std(Xnorm, axis = 0)
    Xstd = Xnorm / np.sqrt((Xnorm * Xnorm).sum(axis = 0))
    return Xstd
def trend_data(n, p, bval = 1.0, sd = 1.0, seed = 100):
    np.random.seed(seed)
    # Truncated linear ramps: column i is (0, ..., 0, 1, 2, ..., n - i),
    # starting at row i. For p > n, columns with index >= n stay zero.
    X = np.zeros((n, p))
    for i in range(p):
        X[i:n, i] = np.arange(1, n - i + 1)
    #X = standardize(X)
    # Two adjacent coefficients of opposite sign, placed at index n/3
    btrue = np.zeros(p)
    idx = int(n / 3)
    btrue[idx] = bval
    btrue[idx + 1] = -bval
    y = np.dot(X, btrue) + np.random.normal(0, sd, n)
    # y = y / np.std(y)
    return X, y, btrue
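A note on this construction, reading off trend_data above: since column $j$ of $\mathbf{X}$ is the truncated ramp $(0, \ldots, 0, 1, 2, \ldots, n - j)$, the two opposite-signed coefficients telescope, and the noiseless signal is a single step rather than a trend:
$$ \left(\mathbf{X}\mathbf{b}_{\mathrm{true}}\right)_s = b_{\mathrm{val}}\,(s - i_0 + 1)_+ - b_{\mathrm{val}}\,(s - i_0)_+ = b_{\mathrm{val}}\,\mathbb{1}\left[ s \geq i_0 \right], \qquad i_0 = \lfloor n/3 \rfloor, $$
so the mean of $y$ is a changepoint of height $b_{\mathrm{val}}$ at sample index $i_0$, which is what the figure below shows.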
n = 100
p = 200
bval = 8.0
sd = 2.0
X, y, btrue = trend_data(n, p, bval = bval, sd = sd)
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.plot(np.arange(n), np.dot(X, btrue), label = "Xb")
ax1.scatter(np.arange(n), y, edgecolor = 'black', facecolor='white', label = "Xb + e")
ax1.legend()
ax1.set_xlabel("Sample index")
ax1.set_ylabel("y")
plt.show()
Fix initial values of $\sigma^2$, $\sigma_b^2$ and $\sigma_w^2$
First, I obtain an optimal fit using point estimates of $\mathbf{b}$ and $\mathbf{w}$ (ignoring the contribution of their variances). I then initialize a second optimization, which does include the variance contributions of $\mathbf{b}$ and $\mathbf{w}$, from the optimal values of all parameters found in the first stage. In this second optimization, the parameters $\sigma^2$, $\sigma_b^2$ and $\sigma_w^2$ are held fixed.
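Concretely, the "contribution of the variance" enters through the second moment of the scaled design. Writing $\bar{\mathbf{W}} = \mathrm{diag}(\mathbb{E}[\mathbf{w}])$ and $\boldsymbol{\Sigma}_w$ for the posterior covariance of $\mathbf{w}$ (off-diagonal covariances dropped, as in the code below), the expectation used in the updates is
$$ \mathbb{E}\left[\mathbf{W}^\intercal \mathbf{X}^\intercal \mathbf{X} \mathbf{W}\right] = \bar{\mathbf{W}}^\intercal \mathbf{X}^\intercal \mathbf{X} \bar{\mathbf{W}} + \mathrm{diag}(\mathbf{V}_w), \qquad \mathbf{V}_w = \mathrm{diag}(\mathbf{X}^\intercal \mathbf{X}) \circ \mathrm{diag}(\boldsymbol{\Sigma}_w), $$
where $\circ$ denotes the elementwise product. The vector $\mathbf{V}_w$ is the variable VW computed in the code, and setting use_wb_variance = False drops this second term, leaving only the point estimates.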
def ridge_mll(X, y, s2, sb2, W):
    '''
    Log marginal likelihood of ridge regression after scaling the
    columns of X by W, i.e. y ~ N(0, s2 * (I + sb2 * XW (XW)^T))
    '''
    n, p = X.shape
    Xscale = np.dot(X, np.diag(W))
    XWWtXt = np.dot(Xscale, Xscale.T)
    sigmay = s2 * (np.eye(n) + sb2 * XWWtXt)
    muy = np.zeros((n, 1))
    return log_density.mgauss(y.reshape(-1,1), muy, sigmay)
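As a quick sanity check, the same quantity can be computed directly with scipy. This is a minimal sketch, assuming log_density.mgauss evaluates a standard multivariate normal log density; ridge_mll_scipy is a hypothetical helper for this comparison, not part of ebmrPy.
from scipy import stats as sc_stats

def ridge_mll_scipy(X, y, s2, sb2, W):
    # Direct computation of the same log marginal likelihood:
    # y ~ N(0, s2 * (I_n + sb2 * XW (XW)^T))
    n = X.shape[0]
    Xscale = X * W  # scale column j of X by W[j]
    sigmay = s2 * (np.eye(n) + sb2 * np.dot(Xscale, Xscale.T))
    return sc_stats.multivariate_normal.logpdf(y, mean = np.zeros(n), cov = sigmay)
For example, ridge_mll_scipy(X, y, 1.0, 1.0, np.ones(p)) should agree with ridge_mll(X, y, 1.0, 1.0, np.ones(p)) up to numerical precision.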
def grr_step(X, y, s2, sb2, muW, varW, XTX, XTy, useVW = True):
    '''
    One step of generalized ridge regression (GRR): update the posterior
    of b given the current moments of w, then update s2 and sb2 from the
    expected complete-data likelihood.
    '''
    n, p = X.shape
    W = np.diag(muW)
    WtXtXW = np.linalg.multi_dot([W.T, XTX, W])
    # VW carries the variance contribution of w, see E[W'X'XW] above
    VW = np.diag(XTX) * np.diag(varW) if useVW else np.zeros(p)
    sigmabinv = (WtXtXW + np.diag(VW) + np.eye(p) * s2 / sb2) / s2
    sigmab = np.linalg.inv(sigmabinv)
    mub = np.linalg.multi_dot([sigmab, W.T, XTy]) / s2
    XWmu = np.linalg.multi_dot([X, W, mub])
    mub2 = np.square(mub)
    s2 = (np.sum(np.square(y - XWmu)) \
          + np.dot((WtXtXW + np.diag(VW)), sigmab).trace() \
          + np.sum(mub2 * VW)) / n
    sb2 = (np.sum(mub2) + sigmab.trace()) / p
    return s2, sb2, mub, sigmab
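In equations (a direct transcription of grr_step above), the posterior update for $\mathbf{b}$ is
$$ \boldsymbol{\Sigma}_b^{-1} = \frac{1}{\sigma^2}\left( \mathbf{W}^\intercal\mathbf{X}^\intercal\mathbf{X}\mathbf{W} + \mathrm{diag}(\mathbf{V}_w) + \frac{\sigma^2}{\sigma_b^2}\mathbb{I}_p \right), \qquad \boldsymbol{\mu}_b = \frac{1}{\sigma^2}\,\boldsymbol{\Sigma}_b \mathbf{W}^\intercal \mathbf{X}^\intercal \mathbf{y}, $$
followed by the variance-component updates
$$ \sigma^2 \leftarrow \frac{1}{n}\left( \lVert \mathbf{y} - \mathbf{X}\mathbf{W}\boldsymbol{\mu}_b \rVert^2 + \mathrm{tr}\left[ \left(\mathbf{W}^\intercal\mathbf{X}^\intercal\mathbf{X}\mathbf{W} + \mathrm{diag}(\mathbf{V}_w)\right)\boldsymbol{\Sigma}_b \right] + \boldsymbol{\mu}_b^\intercal \,\mathrm{diag}(\mathbf{V}_w)\, \boldsymbol{\mu}_b \right), \qquad \sigma_b^2 \leftarrow \frac{1}{p}\left( \lVert \boldsymbol{\mu}_b \rVert^2 + \mathrm{tr}(\boldsymbol{\Sigma}_b) \right). $$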
def elbo(X, y, s2, sb2, sw2, mub, sigmab, Wbar, varW, XTX, useVW = True):
    '''
    Evidence lower bound, split into three terms (see c_func, h1_func
    and h2_func below).
    Wbar is a vector holding the diagonal elements of the diagonal
    matrix W, that is W = diag_matrix(Wbar) and Wbar = diag(W).
    varW is the posterior covariance of w; its diagonal, scaled by
    diag(XTX), gives the vector VW passed to the helper functions.
    '''
    n, p = X.shape
    VW = np.diag(XTX) * np.diag(varW) if useVW else np.zeros(p)
    elbo = c_func(n, p, s2, sb2, sw2) \
           + h1_func(X, y, s2, sb2, sw2, mub, Wbar, VW) \
           + h2_func(p, s2, sb2, sw2, XTX, Wbar, sigmab, varW, VW)
    return elbo
def c_func(n, p, s2, sb2, sw2):
    val = p
    val += - 0.5 * n * np.log(2.0 * np.pi * s2)
    val += - 0.5 * p * np.log(sb2)
    val += - 0.5 * p * np.log(sw2)
    return val

def h1_func(X, y, s2, sb2, sw2, mub, Wbar, VW):
    XWmu = np.linalg.multi_dot([X, np.diag(Wbar), mub])
    val1 = - (0.5 / s2) * np.sum(np.square(y - XWmu))
    val2 = - 0.5 * np.sum(np.square(mub) * ((VW / s2) + (1 / sb2)))
    val3 = - 0.5 * np.sum(np.square(Wbar)) / sw2
    val = val1 + val2 + val3
    return val

def h2_func(p, s2, sb2, sw2, XTX, Wbar, sigmab, sigmaw, VW):
    (sign, logdetS) = np.linalg.slogdet(sigmab)
    (sign, logdetV) = np.linalg.slogdet(sigmaw)
    W = np.diag(Wbar)
    WtXtXW = np.linalg.multi_dot([W.T, XTX, W])
    val = 0.5 * logdetS + 0.5 * logdetV
    val += - 0.5 * np.trace(sigmab) / sb2 - 0.5 * np.trace(sigmaw) / sw2
    val += - 0.5 * np.dot(WtXtXW + np.diag(VW), sigmab).trace() / s2
    return val
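Putting the three helpers together (again just transcribing the code), the bound being tracked is
$$ \mathrm{ELBO} = \underbrace{p - \frac{n}{2}\ln(2\pi\sigma^2) - \frac{p}{2}\ln\sigma_b^2 - \frac{p}{2}\ln\sigma_w^2}_{c} \;\underbrace{- \frac{1}{2\sigma^2}\lVert \mathbf{y} - \mathbf{X}\mathbf{W}\boldsymbol{\mu}_b \rVert^2 - \frac{1}{2}\sum_j \mu_{b,j}^2 \left( \frac{V_{w,j}}{\sigma^2} + \frac{1}{\sigma_b^2} \right) - \frac{1}{2\sigma_w^2}\lVert \bar{\mathbf{W}} \rVert^2}_{h_1} \; + \; h_2, $$
$$ h_2 = \frac{1}{2}\ln\lvert\boldsymbol{\Sigma}_b\rvert + \frac{1}{2}\ln\lvert\boldsymbol{\Sigma}_w\rvert - \frac{\mathrm{tr}(\boldsymbol{\Sigma}_b)}{2\sigma_b^2} - \frac{\mathrm{tr}(\boldsymbol{\Sigma}_w)}{2\sigma_w^2} - \frac{1}{2\sigma^2}\mathrm{tr}\left[ \left( \mathbf{W}^\intercal\mathbf{X}^\intercal\mathbf{X}\mathbf{W} + \mathrm{diag}(\mathbf{V}_w) \right)\boldsymbol{\Sigma}_b \right]. $$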
def ebmr_WB2(X, y,
             s2_init = 1.0, sb2_init = 1.0, sw2_init = 1.0,
             binit = None, winit = None,
             sigmab_init = None, sigmaw_init = None,
             use_wb_variance = True,
             ignore_b_update = False,
             ignore_w_update = False,
             ignore_s2_update = False,
             ignore_s2_sb2_update = False,
             ignore_sw2_update = False,
             max_iter = 10000, tol = 1e-8
            ):
    XTX = np.dot(X.T, X)
    XTy = np.dot(X.T, y)
    n_samples, n_features = X.shape
    elbo_path = np.zeros(max_iter + 1)
    mll_path = np.zeros(max_iter + 1)
    '''
    Iteration 0
    '''
    niter = 0
    s2 = s2_init
    sb2 = sb2_init
    sw2 = sw2_init
    mub = np.ones(n_features) if binit is None else binit
    muw = np.ones(n_features) if winit is None else winit
    sigmab = np.zeros((n_features, n_features)) if sigmab_init is None else sigmab_init
    sigmaw = np.zeros((n_features, n_features)) if sigmaw_init is None else sigmaw_init
    elbo_path[0] = -np.inf
    mll_path[0] = -np.inf
    for itn in range(1, max_iter + 1):
        '''
        GRR for b
        '''
        if not ignore_b_update:
            if ignore_s2_update:
                # update b, keeping s2 fixed
                __, sb2, mub, sigmab = grr_step(X, y, s2, sb2, muw, sigmaw, XTX, XTy, useVW = use_wb_variance)
            elif ignore_s2_sb2_update:
                # update b, keeping s2 and sb2 fixed
                __, ___, mub, sigmab = grr_step(X, y, s2, sb2, muw, sigmaw, XTX, XTy, useVW = use_wb_variance)
            else:
                # update b, s2 and sb2
                s2, sb2, mub, sigmab = grr_step(X, y, s2, sb2, muw, sigmaw, XTX, XTy, useVW = use_wb_variance)
        '''
        GRR for w
        (b and w play symmetric roles, so grr_step is reused with arguments swapped)
        '''
        if not ignore_w_update:
            if ignore_sw2_update:
                # update w, keeping sw2 fixed
                __, ___, muw, sigmaw = grr_step(X, y, s2, sw2, mub, sigmab, XTX, XTy, useVW = use_wb_variance)
            else:
                # update w and sw2
                __, sw2, muw, sigmaw = grr_step(X, y, s2, sw2, mub, sigmab, XTX, XTy, useVW = use_wb_variance)
        '''
        Convergence
        '''
        niter += 1
        elbo_path[itn] = elbo(X, y, s2, sb2, sw2, mub, sigmab, muw, sigmaw, XTX, useVW = use_wb_variance)
        mll_path[itn] = ridge_mll(X, y, s2, sb2, muw)
        if elbo_path[itn] - elbo_path[itn - 1] < tol: break
        #if mll_path[itn] - mll_path[itn - 1] < tol: break
    return s2, sb2, sw2, mub, sigmab, muw, sigmaw, niter, elbo_path[:niter + 1], mll_path[:niter + 1]
m1 = ebmr_WB2(X, y, use_wb_variance = False)
# s2, sb2, sw2, mub, sigmab, W, sigmaW, niter, elbo_path, mll_path = m1
m2 = ebmr_WB2(X, y, use_wb_variance=True,
s2_init = m1[0],
sb2_init = m1[1],
sw2_init = m1[2],
binit = m1[3],
sigmab_init = m1[4],
winit = m1[5],
sigmaw_init = m1[6],
ignore_s2_sb2_update = True,
ignore_sw2_update = True,
tol = 1e-3,
)
s2, sb2, sw2, mub, sigmab, W, sigmaW, niter, elbo_path, mll_path = m2
bpred = mub * W
ypred = np.dot(X, bpred)
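Since the model uses a product of normals, the effective regression coefficients are $\boldsymbol{\beta} = \mathbf{b} \circ \mathbf{w}$; assuming the variational posteriors of $\mathbf{b}$ and $\mathbf{w}$ are independent, the posterior mean is simply
$$ \hat{\boldsymbol{\beta}} = \mathbb{E}\left[\mathbf{b} \circ \mathbf{w}\right] = \boldsymbol{\mu}_b \circ \bar{\mathbf{W}}, $$
which is the elementwise product bpred = mub * W computed above.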
fig = plt.figure(figsize = (12, 12))
ax1 = fig.add_subplot(221)
ax2 = fig.add_subplot(222)
ax3 = fig.add_subplot(223)
#ax4 = fig.add_subplot(224)
yvals = np.log(np.max(elbo_path[1:]) - elbo_path[1:] + 1)
ax1.scatter(np.arange(niter), yvals, edgecolor = 'black', facecolor='white')
ax1.plot(np.arange(niter), yvals)
ax1.set_xlabel("Iterations")
ax1.set_ylabel("log (max(ELBO) - ELBO[itn] + 1)")
ax2.scatter(np.arange(n), y, edgecolor = 'black', facecolor='white')
ax2.plot(np.arange(n), ypred, color = 'salmon', label="Predicted")
ax2.plot(np.arange(n), np.dot(X, btrue), color = 'dodgerblue', label="True")
ax2.legend()
ax2.set_xlabel("Sample Index")
ax2.set_ylabel("y")
ax3.scatter(np.arange(p), btrue, edgecolor = 'black', facecolor='white', label="True")
ax3.scatter(np.arange(p), bpred, label="Predicted")
ax3.legend()
ax3.set_xlabel("Predictor Index")
ax3.set_ylabel("wb")
# nstep = min(80, niter - 2)
# ax4.scatter(np.arange(nstep), mll_path[-nstep:], edgecolor = 'black', facecolor='white', label="Evidence")
# ax4.plot(np.arange(nstep), elbo_path[-nstep:], label="ELBO")
# ax4.legend()
# ax4.set_xlabel("Iterations")
# ax4.set_ylabel("ELBO / Evidence")
plt.tight_layout()
plt.show()
Dependence on $\sigma_b$ and $\sigma_w$
Next, I look at the contour of the converged (maximum) ELBO over a grid of $(\sigma_b^2, \sigma_w^2)$ values, which are held fixed during each optimization. Every optimization used a fixed $\sigma^2 = 3.89$ (obtained from the previous optimization), and iterations were stopped at a tolerance of 1e-3.
def optimize_elbo(X, y, s2_init, sb2_init, sw2_init, binit, winit, sigmab_init, sigmaw_init):
    # Optimize with s2, sb2 and sw2 held fixed at their initial values,
    # and return the ELBO at convergence.
    mres = ebmr_WB2(X, y, use_wb_variance = True,
                    s2_init = s2_init,
                    sb2_init = sb2_init,
                    sw2_init = sw2_init,
                    binit = binit,
                    sigmab_init = sigmab_init,
                    winit = winit,
                    sigmaw_init = sigmaw_init,
                    ignore_s2_sb2_update = True,
                    ignore_sw2_update = True,
                    tol = 1e-3, max_iter = 500,
                   )
    s2, sb2, sw2, mub, sigmab, W, sigmaW, niter, elbo_path, mll_path = mres
    return elbo_path[-1]
def plot_contours(ax, X, Y, Z, beta, norm, cstep = 10,
                  xlabel = "", ylabel = "", zlabel = "",
                  showbeta = False, showZmax = False):
    zmin = np.min(Z) - 1 * np.std(Z)
    zmax = np.max(Z) + 1 * np.std(Z)
    ind = np.unravel_index(np.argmax(Z, axis=None), Z.shape)
    levels = np.linspace(zmin, zmax, 200)
    clevels = np.linspace(zmin, zmax, 20)
    cmap = cm.YlOrRd_r
    if norm:
        cset1 = ax.contourf(X, Y, Z, levels, norm = norm,
                            cmap = cm.get_cmap(cmap, len(levels) - 1))
    else:
        cset1 = ax.contourf(X, Y, Z, levels,
                            cmap = cm.get_cmap(cmap, len(levels) - 1))
    cset2 = ax.contour(X, Y, Z, clevels, colors = 'k')
    for c in cset2.collections:
        c.set_linestyle('solid')
    ax.set_aspect("equal")
    if showbeta:
        ax.scatter(beta[0], beta[1], color = 'blue', s = 100)
    if showZmax:
        ax.scatter(X[ind[1]], Y[ind[0]], color = 'k', s = 100)
    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size = "5%", pad = 0.2)
    cbar = plt.colorbar(cset1, cax = cax)
    ytickpos = np.arange(int(zmin / cstep) * cstep, zmax, cstep)
    cbar.set_ticks(ytickpos)
    if zlabel:
        cax.set_ylabel(zlabel)
    #loc = plticker.AutoLocator()
    #ax.xaxis.set_major_locator(loc)
    #ax.yaxis.set_major_locator(loc)
k = 25
sb2_vals = np.logspace(-5, 1, k)
sw2_vals = np.logspace(-5, 1, k)
s2_init = m1[0]
binit = m1[3]
winit = m1[5]
sigmab_init = m1[4]
sigmaw_init = m1[6]
ELBO = np.zeros((k, k))
for i, sb2_init in enumerate(sb2_vals):
    for j, sw2_init in enumerate(sw2_vals):
        # rows of ELBO index sw2, columns index sb2
        ELBO[j, i] = optimize_elbo(X, y, s2_init, sb2_init, sw2_init,
                                   binit, winit, sigmab_init, sigmaw_init)
fig = plt.figure(figsize = (10, 8))
ax1 = fig.add_subplot(111)
norm = cm.colors.Normalize(vmin=np.min(ELBO), vmax=np.max(ELBO))
#norm = cm.colors.TwoSlopeNorm(vmin=np.min(ELBO), vcenter=np.max(ELBO)-100, vmax=np.max(ELBO))
plot_contours(ax1, np.log(sb2_vals), np.log(sw2_vals), ELBO, [-2.3, -6.9],
              norm = norm, cstep = 100,
              showbeta = False, showZmax = False,
              xlabel = r"$\ln(\sigma_b^2)$",
              ylabel = r"$\ln(\sigma_w^2)$",
              zlabel = "ELBO",
              )
plt.tight_layout()
plt.show()