Sequence of updates in EBMR with product of normals
About
Here, I am simply checking whether the sequence of updates has any effect on the optimization in the variational approximation of EBMR. Earlier, I found that the variational approximation for the product of two normals leads to severe overfitting (see here).
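For reference, here is the model as I read it off the code below (a sketch; the notation is mine):

$$ y = X\,\mathrm{diag}(w)\,b + \epsilon, \qquad \epsilon \sim \mathcal{N}(0, \sigma^2 I_n), \qquad b_j \sim \mathcal{N}(0, \sigma_b^2), \qquad w_j \sim \mathcal{N}(0, \sigma_w^2), $$

so each effective regression coefficient is the product $w_j b_j$ of two normals. In the code, $\sigma^2$, $\sigma_b^2$ and $\sigma_w^2$ appear as s2, sb2 and sw2, and the two algorithms below differ only in the order in which the variational factors and these hyperparameters are updated.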
import numpy as np
import pandas as pd
from scipy import linalg as sc_linalg
import matplotlib.pyplot as plt
import sys
sys.path.append("../../ebmrPy/")
from inference.ebmr import EBMR
from inference import f_elbo
from inference import f_sigma
from inference import penalized_em
from utils import log_density
sys.path.append("../../utils/")
import mpl_stylesheet
mpl_stylesheet.banskt_presentation(fontfamily = 'latex-clearsans', fontsize = 18, colors = 'banskt', dpi = 72)
def standardize(X):
    # center each column and scale it to unit L2 norm
    Xnorm = X - np.mean(X, axis = 0)
    #Xstd = Xnorm / np.std(Xnorm, axis = 0)
    Xstd = Xnorm / np.sqrt((Xnorm * Xnorm).sum(axis = 0))
    return Xstd
def trend_data(n, p, bval = 1.0, sd = 1.0, seed = 100):
    # Column i of X is a linear ramp starting at row i.
    np.random.seed(seed)
    X = np.zeros((n, p))
    for i in range(p):
        X[i:n, i] = np.arange(1, n - i + 1)
    # The +bval / -bval pair at adjacent indices makes np.dot(X, btrue)
    # a step of height bval starting at sample idx.
    btrue = np.zeros(p)
    idx = int(n / 3)
    btrue[idx] = bval
    btrue[idx + 1] = -bval
    y = np.dot(X, btrue) + np.random.normal(0, sd, n)
    # y = y / np.std(y)
    return X, y, btrue
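As a quick sanity check of the simulation (a throwaway example, not part of the analysis below), each column of X is a linear ramp starting one row later than the previous column:

Xtoy, ytoy, btoy = trend_data(6, 4, bval = 1.0, sd = 0.1)
print(Xtoy)
# [[1. 0. 0. 0.]
#  [2. 1. 0. 0.]
#  [3. 2. 1. 0.]
#  [4. 3. 2. 1.]
#  [5. 4. 3. 2.]
#  [6. 5. 4. 3.]]

The difference of two adjacent ramps is constant, so np.dot(X, btrue) is a step function.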
n = 100
p = 200
bval = 8.0
sd = 2.0
X, y, btrue = trend_data(n, p, bval = bval, sd = sd)
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.plot(np.arange(n), np.dot(X, btrue), label = "Xb")
ax1.scatter(np.arange(n), y, edgecolor = 'black', facecolor='white', label = "Xb + e")
ax1.legend()
ax1.set_xlabel("Sample index")
ax1.set_ylabel("y")
plt.show()
def ridge_mll(X, y, s2, sb2, W):
    # log marginal likelihood (evidence) of the ridge model with W fixed:
    # y ~ N(0, s2 * (I + sb2 * XW (XW)^T))
    n, p = X.shape
    Xscale = np.dot(X, np.diag(W))
    XWWtXt = np.dot(Xscale, Xscale.T)
    sigmay = s2 * (np.eye(n) + sb2 * XWWtXt)
    muy = np.zeros((n, 1))
    return log_density.mgauss(y.reshape(-1, 1), muy, sigmay)
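log_density.mgauss comes from the ebmrPy repo. For readers without it, a minimal stand-in (assuming that function returns the log density of a multivariate normal, which is how it is used here) could be:

from scipy.stats import multivariate_normal

def mgauss_stub(y, mu, sigma):
    # log N(y | mu, sigma); a hypothetical drop-in for log_density.mgauss
    return multivariate_normal.logpdf(y.ravel(), mean = mu.ravel(), cov = sigma)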
def grr_b(X, y, s2, sb2, Wbar, varWj, XTX, XTy):
    # GRR step for b: update q(b) = N(mub, sigmab) given the moments of w,
    # then update s2 and sb2.
    n, p = X.shape
    W = np.diag(Wbar)
    WtXtXW = np.linalg.multi_dot([W.T, XTX, W])
    VW = np.diag(XTX) * varWj
    sigmabinv = (WtXtXW + np.diag(VW) + np.eye(p) * s2 / sb2) / s2
    sigmab = np.linalg.inv(sigmabinv)
    mub = np.linalg.multi_dot([sigmab, W.T, XTy]) / s2
    XWmu = np.linalg.multi_dot([X, W, mub])
    mub2 = np.square(mub)
    s2 = (np.sum(np.square(y - XWmu)) \
          + np.dot((WtXtXW + np.diag(VW)), sigmab).trace() \
          + np.sum(mub2 * VW)) / n
    sb2 = (np.sum(mub2) + sigmab.trace()) / p
    return s2, sb2, mub, sigmab
def grr_W_old(X, y, s2, sw2, mub, sigmab, muWj, XTX, XTy):
    # Coordinate-wise GRR step for W: muWj is updated in place, so each
    # wXTXRj uses the latest values (note the extra 0.5 factor in the mean
    # update, which is absent in grr_W below).
    n, p = X.shape
    R = np.einsum('i,j->ij', mub, mub) + sigmab
    XTXRjj = np.array([XTX[j, j] * R[j, j] for j in range(p)])
    #wXTXRj = np.array([np.sum(muWj * XTX[:, j] * R[:, j]) - (muWj[j] * XTXRjj[j]) for j in range(p)])
    sigmaWj2 = 1 / ((XTXRjj / s2) + (1 / sw2))
    for j in range(p):
        wXTXRj = np.sum(muWj * XTX[:, j] * R[:, j]) - (muWj[j] * XTXRjj[j])
        muWj[j] = sigmaWj2[j] * (mub[j] * XTy[j] - 0.5 * wXTXRj) / s2
    sw2 = np.sum(np.square(muWj) + sigmaWj2) / p
    return sw2, muWj, sigmaWj2
def grr_W(X, y, s2, sw2, mub, sigmab, muWj, XTX, XTy):
    # Vectorized GRR step for W: wXTXRj is computed from the muWj of the
    # previous iteration, unlike the in-place loop in grr_W_old.
    n, p = X.shape
    R = np.einsum('i,j->ij', mub, mub) + sigmab
    XTXRjj = np.diag(XTX) * np.diag(R)
    sigmaWj2inv = (XTXRjj / s2) + (1 / sw2)
    wXTXRj = np.array([np.sum(muWj * XTX[:, j] * R[:, j]) - (muWj[j] * XTXRjj[j]) for j in range(p)])
    sigmaWj2 = 1 / sigmaWj2inv
    muWj = sigmaWj2 * (mub * XTy - wXTXRj) / s2
    sw2 = np.sum(np.square(muWj) + sigmaWj2) / p
    #sigmaWj2 = np.zeros(p)
    return sw2, muWj, sigmaWj2
def elbo(X, y, s2, sb2, sw2, mub, sigmab, Wbar, varWj, XTX):
    '''
    Wbar is a vector containing the diagonal elements of the diagonal matrix W:
        W = diag_matrix(Wbar)
        Wbar = diag(W)
    VW is a vector containing the diagonal elements of the diagonal matrix V_w.
    '''
    n, p = X.shape
    VW = np.diag(XTX) * varWj
    elbo = c_func(n, p, s2, sb2, sw2) \
           + h1_func(X, y, s2, sb2, sw2, mub, Wbar, VW) \
           + h2_func(p, s2, sb2, sw2, XTX, Wbar, sigmab, varWj, VW)
    return elbo

def c_func(n, p, s2, sb2, sw2):
    val = p
    val += - 0.5 * n * np.log(2.0 * np.pi * s2)
    val += - 0.5 * p * np.log(sb2)
    val += - 0.5 * p * np.log(sw2)
    return val

def h1_func(X, y, s2, sb2, sw2, mub, Wbar, VW):
    XWmu = np.linalg.multi_dot([X, np.diag(Wbar), mub])
    val1 = - (0.5 / s2) * np.sum(np.square(y - XWmu))
    val2 = - 0.5 * np.sum(np.square(mub) * ((VW / s2) + (1 / sb2)))
    val3 = - 0.5 * np.sum(np.square(Wbar)) / sw2
    val = val1 + val2 + val3
    return val

def h2_func(p, s2, sb2, sw2, XTX, Wbar, sigmab, varWj, VW):
    (sign, logdetS) = np.linalg.slogdet(sigmab)
    logdetV = np.sum(np.log(varWj))
    W = np.diag(Wbar)
    WtXtXW = np.linalg.multi_dot([W.T, XTX, W])
    val = 0.5 * logdetS + 0.5 * logdetV
    val += - 0.5 * np.trace(sigmab) / sb2 - 0.5 * np.sum(varWj) / sw2
    val += - 0.5 * np.dot(WtXtXW + np.diag(VW), sigmab).trace() / s2
    return val
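Putting the three helpers together, the quantity being tracked is (a transcription of the code above into math, with $V_w = \mathrm{diag}\big((X^\top X)_{jj}\,\mathrm{var}(w_j)\big)$ and $\bar{W} = \mathrm{diag}(\bar{w})$):

$$ \mathrm{ELBO} = p - \tfrac{n}{2}\ln(2\pi\sigma^2) - \tfrac{p}{2}\ln\sigma_b^2 - \tfrac{p}{2}\ln\sigma_w^2 - \tfrac{1}{2\sigma^2}\lVert y - X\bar{W}\mu_b \rVert^2 - \tfrac{1}{2}\sum_j \mu_{b,j}^2 \Big( \tfrac{(V_w)_{jj}}{\sigma^2} + \tfrac{1}{\sigma_b^2} \Big) - \tfrac{\lVert \bar{w} \rVert^2}{2\sigma_w^2} $$
$$ \qquad + \tfrac{1}{2}\ln\lvert\Sigma_b\rvert + \tfrac{1}{2}\sum_j \ln \mathrm{var}(w_j) - \tfrac{\mathrm{tr}(\Sigma_b)}{2\sigma_b^2} - \tfrac{\sum_j \mathrm{var}(w_j)}{2\sigma_w^2} - \tfrac{1}{2\sigma^2}\,\mathrm{tr}\Big[ \big(\bar{W}^\top X^\top X \bar{W} + V_w\big)\,\Sigma_b \Big]. $$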
def ebmr_WB1(X, y,
             s2_init = 1.0, sb2_init = 1.0, sw2_init = 1.0,
             binit = None, winit = None,
             max_iter = 1000, tol = 1e-8):
    XTX = np.dot(X.T, X)
    XTy = np.dot(X.T, y)
    n_samples, n_features = X.shape
    elbo_path = np.zeros(max_iter + 1)
    mll_path = np.zeros(max_iter + 1)
    # Iteration 0
    niter = 0
    s2 = s2_init
    sb2 = sb2_init
    sw2 = sw2_init
    mub = np.ones(n_features) if binit is None else binit
    muWj = np.ones(n_features) if winit is None else winit
    sigmab = np.zeros((n_features, n_features))
    sigmaWj2 = np.zeros(n_features)
    elbo_path[0] = -np.inf
    mll_path[0] = -np.inf
    for itn in range(1, max_iter + 1):
        # GRR for b
        s2, sb2, mub, sigmab = grr_b(X, y, s2, sb2, muWj, sigmaWj2, XTX, XTy)
        # GRR for W
        sw2, muWj, sigmaWj2 = grr_W(X, y, s2, sw2, mub, sigmab, muWj, XTX, XTy)
        # Convergence
        niter += 1
        elbo_path[itn] = elbo(X, y, s2, sb2, sw2, mub, sigmab, muWj, sigmaWj2, XTX)
        mll_path[itn] = ridge_mll(X, y, s2, sb2, muWj)
        if elbo_path[itn] - elbo_path[itn - 1] < tol: break
        #if mll_path[itn] - mll_path[itn - 1] < tol: break
    return s2, sb2, sw2, mub, sigmab, muWj, sigmaWj2, niter, elbo_path[:niter + 1], mll_path[:niter + 1]
As seen previously, this sequence of updates (the GRR step for b, then the GRR step for W, each updating its own hyperparameters) leads to overfitting.
m1 = ebmr_WB1(X, y)
s2, sb2, sw2, mub, sigmab, W, sigmaW, niter, elbo_path, mll_path = m1
bpred = mub * W
ypred = np.dot(X, bpred)
fig = plt.figure(figsize = (12, 12))
ax1 = fig.add_subplot(221)
ax2 = fig.add_subplot(222)
ax3 = fig.add_subplot(223)
ax4 = fig.add_subplot(224)
ax1.scatter(np.arange(niter-1), elbo_path[2:], edgecolor = 'black', facecolor='white')
ax1.plot(np.arange(niter-1), elbo_path[2:])
ax1.set_xlabel("Iterations")
ax1.set_ylabel("ELBO")
ax2.scatter(np.arange(n), y, edgecolor = 'black', facecolor='white')
ax2.plot(np.arange(n), ypred, color = 'salmon', label="Predicted")
ax2.plot(np.arange(n), np.dot(X, btrue), color = 'dodgerblue', label="True")
ax2.legend()
ax2.set_xlabel("Sample Index")
ax2.set_ylabel("y")
ax3.scatter(np.arange(p), btrue, edgecolor = 'black', facecolor='white', label="True")
ax3.scatter(np.arange(p), bpred, label="Predicted")
ax3.legend()
ax3.set_xlabel("Predictor Index")
ax3.set_ylabel("wb")
nstep = min(80, niter - 2)
ax4.scatter(np.arange(nstep), mll_path[-nstep:], edgecolor = 'black', facecolor='white', label="Evidence")
ax4.plot(np.arange(nstep), elbo_path[-nstep:], label="ELBO")
ax4.legend()
ax4.set_xlabel("Iterations")
ax4.set_ylabel("ELBO / Evidence")
plt.tight_layout()
plt.show()
Now I change the sequence of updates: first update the full variational posterior q(b, w), and only then update the three hyperparameters s2, sb2 and sw2 together.

def update_qbw(X, s2, sb2, sw2, mub, sigmab, Wbar, varWj, XTX, XTy):
    n, p = X.shape
    W = np.diag(Wbar)
    WtXtXW = np.linalg.multi_dot([W.T, XTX, W])
    VW = np.diag(XTX) * varWj
    # update mub and sigmab
    sigmabinv = (WtXtXW + np.diag(VW) + np.eye(p) * s2 / sb2) / s2
    sigmab = np.linalg.inv(sigmabinv)
    mub = np.linalg.multi_dot([sigmab, W.T, XTy]) / s2
    # update Wbar and varWj
    R = np.einsum('i,j->ij', mub, mub) + sigmab
    XTXRjj = np.diag(XTX) * np.diag(R)
    varWjinv = (XTXRjj / s2) + (1 / sw2)
    varWj = 1 / varWjinv
    # coordinate-wise update of Wbar, in place, so each wXTXRj uses the
    # latest values; the vectorized alternative is commented out below
    for j in range(p):
        wXTXRj = np.sum(Wbar * XTX[:, j] * R[:, j]) - (Wbar[j] * XTXRjj[j])
        Wbar[j] = varWj[j] * (mub[j] * XTy[j] - wXTXRj) / s2
    #wXTXRj = np.array([np.sum(Wbar * XTX[:, j] * R[:, j]) - (Wbar[j] * XTXRjj[j]) for j in range(p)])
    #Wbar = varWj * (mub * XTy - wXTXRj) / s2
    return mub, sigmab, Wbar, varWj
def update_params(X, y, mub, sigmab, Wbar, varWj, XTX):
    # update the parameters s2, sb2 and sw2 from the current moments
    n, p = X.shape
    W = np.diag(Wbar)
    VW = np.diag(XTX) * varWj
    WtXtXW = np.linalg.multi_dot([W.T, XTX, W])
    XWmu = np.linalg.multi_dot([X, W, mub])
    mub2 = np.square(mub)
    s2 = (np.sum(np.square(y - XWmu)) \
          + np.dot((WtXtXW + np.diag(VW)), sigmab).trace() \
          + np.sum(mub2 * VW)) / n
    sb2 = np.sum(mub2 + np.diag(sigmab)) / p
    sw2 = np.sum(np.square(Wbar) + varWj) / p
    return s2, sb2, sw2
def ebmr_WB2(X, y,
             s2_init = 1.0, sb2_init = 1.0, sw2_init = 1.0,
             binit = None, winit = None,
             max_iter = 1000, tol = 1e-8):
    XTX = np.dot(X.T, X)
    XTy = np.dot(X.T, y)
    n_samples, n_features = X.shape
    elbo_path = np.zeros(max_iter + 1)
    mll_path = np.zeros(max_iter + 1)
    # Iteration 0
    niter = 0
    s2 = s2_init
    sb2 = sb2_init
    sw2 = sw2_init
    mub = np.ones(n_features) if binit is None else binit
    muWj = np.ones(n_features) if winit is None else winit
    sigmab = np.zeros((n_features, n_features))
    sigmaWj2 = np.zeros(n_features)
    elbo_path[0] = -np.inf
    mll_path[0] = -np.inf
    for itn in range(1, max_iter + 1):
        # Update q(b, w)
        mub, sigmab, muWj, sigmaWj2 = update_qbw(X, s2, sb2, sw2, mub, sigmab, muWj, sigmaWj2, XTX, XTy)
        # Update s2, sb2, sw2
        s2, sb2, sw2 = update_params(X, y, mub, sigmab, muWj, sigmaWj2, XTX)
        # Convergence
        niter += 1
        elbo_path[itn] = elbo(X, y, s2, sb2, sw2, mub, sigmab, muWj, sigmaWj2, XTX)
        mll_path[itn] = ridge_mll(X, y, s2, sb2, muWj)
        if elbo_path[itn] - elbo_path[itn - 1] < tol: break
        #if mll_path[itn] - mll_path[itn - 1] < tol: break
    return s2, sb2, sw2, mub, sigmab, muWj, sigmaWj2, niter, elbo_path[:niter + 1], mll_path[:niter + 1]
However, there is still overfitting with this sequence of updates.
m2 = ebmr_WB2(X, y)
s2, sb2, sw2, mub, sigmab, W, sigmaW, niter, elbo_path, mll_path = m2
bpred = mub * W
ypred = np.dot(X, bpred)
fig = plt.figure(figsize = (12, 12))
ax1 = fig.add_subplot(221)
ax2 = fig.add_subplot(222)
ax3 = fig.add_subplot(223)
ax4 = fig.add_subplot(224)
ax1.scatter(np.arange(niter-1), elbo_path[2:], edgecolor = 'black', facecolor='white')
ax1.plot(np.arange(niter-1), elbo_path[2:])
ax1.set_xlabel("Iterations")
ax1.set_ylabel("ELBO")
ax2.scatter(np.arange(n), y, edgecolor = 'black', facecolor='white')
ax2.plot(np.arange(n), ypred, color = 'salmon', label="Predicted")
ax2.plot(np.arange(n), np.dot(X, btrue), color = 'dodgerblue', label="True")
ax2.legend()
ax2.set_xlabel("Sample Index")
ax2.set_ylabel("y")
ax3.scatter(np.arange(p), btrue, edgecolor = 'black', facecolor='white', label="True")
ax3.scatter(np.arange(p), bpred, label="Predicted")
ax3.legend()
ax3.set_xlabel("Predictor Index")
ax3.set_ylabel("wb")
nstep = min(80, niter - 2)
ax4.scatter(np.arange(nstep), mll_path[-nstep:], edgecolor = 'black', facecolor='white', label="Evidence")
ax4.plot(np.arange(nstep), elbo_path[-nstep:], label="ELBO")
ax4.legend()
ax4.set_xlabel("Iterations")
ax4.set_ylabel("ELBO / Evidence")
plt.tight_layout()
plt.show()