Check ELBO for VEB with product of two normals in case of simple regression
About
Peter suggested checking whether the ELBO calculation reduces to the correct value when $p=1$. Here, I calculate the ELBO for some given values of the hyperparameters and variational parameters, and compare the numerical values obtained from the ELBO code written earlier with those from a simpler version of the ELBO that uses univariate normals. Both are the same, at least numerically, in the limit of $p=1$.
import numpy as np
import pandas as pd
from scipy import linalg as sc_linalg
import matplotlib.pyplot as plt
import sys
sys.path.append("../../ebmrPy/")
from inference.ebmr import EBMR
from inference import f_elbo
from inference import f_sigma
from inference import penalized_em
from utils import log_density
sys.path.append("../../utils/")
import mpl_stylesheet
# Apply the plotting style from the project-local `mpl_stylesheet` module
# (imported above from ../../utils/; its implementation is not visible here).
mpl_stylesheet.banskt_presentation(fontfamily = 'latex-clearsans', fontsize = 18, colors = 'banskt', dpi = 72)
def standardize(X):
    """Center each column of X and scale it to unit Euclidean norm.

    Note that the scaling is by the column's L2 norm (square root of the
    sum of squares of the centered values), not by its standard deviation.

    Parameters
    ----------
    X : 2-D array, one column per variable.

    Returns
    -------
    Array of the same shape whose columns have zero mean and unit L2 norm.
    """
    centered = X - np.mean(X, axis=0)
    column_norms = np.sqrt(np.sum(centered * centered, axis=0))
    return centered / column_norms
def lasso_data(nsample, nvar, neff, errsigma, sb2 = 100, seed=100):
    """Simulate a sparse linear-regression data set.

    Parameters
    ----------
    nsample  : number of rows of the design matrix.
    nvar     : number of columns (predictors).
    neff     : number of predictors given a non-zero coefficient.
    errsigma : standard deviation of the Gaussian noise added to the response.
    sb2      : variance of the Gaussian from which non-zero coefficients are drawn.
    seed     : value passed to np.random.seed.

    Returns
    -------
    (X, y, btrue) : the standardized design matrix, the mean-centered
    response, and the coefficient vector used to generate it.

    Side effect: reseeds NumPy's global random state.
    """
    np.random.seed(seed)

    # The random draws below must stay in this order to reproduce the same
    # data for a given seed: design, active set, effect sizes, noise.
    raw_design = np.random.normal(0, 1, nsample * nvar)
    design = standardize(raw_design.reshape(nsample, nvar))

    coef = np.zeros(nvar)
    active = np.random.choice(nvar, neff, replace = False)
    coef[active] = np.random.normal(0, np.sqrt(sb2), neff)

    noise = np.random.normal(0, errsigma, nsample)
    response = np.dot(design, coef) + noise

    # The response is only centered, not rescaled.
    return design, response - np.mean(response), coef
def lims_xy(ax):
    """Return [lo, hi] spanning both the x- and y-limits of the axes `ax`.

    `lo` is the smallest and `hi` the largest value found among
    ax.get_xlim() and ax.get_ylim().
    """
    both_axes = [ax.get_xlim(), ax.get_ylim()]
    return [np.min(both_axes), np.max(both_axes)]
def plot_diag(ax):
    """Draw a dotted gray y = x reference line across the current limits of `ax`."""
    # Same [lo, hi] pair for x and y so the segment lies on the diagonal.
    span = lims_xy(ax)
    ax.plot(span, span, color='gray', ls='dotted')
# Simulation settings. These are module-level names that later cells rely on.
# n: number of samples, p: number of predictors, peff: number of predictors
# with a non-zero coefficient (here a single-predictor regression).
n = 100
p = 1
peff = 1
# sb2 is handed to lasso_data as the variance of the non-zero coefficients,
# sd as the standard deviation of the residual noise.
sb2 = 100.0
sd = 2.0
X, y, btrue = lasso_data(n, p, peff, sd, sb2)
# Scatter the noiseless fit X.btrue against the observed response y, with a
# dotted y = x reference line for comparison.
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.scatter(np.dot(X,btrue), y)
plot_diag(ax1)
ax1.set_xlabel("Xb")
ax1.set_ylabel("y")
plt.show()
def elbo(X, y, s2, sb2, sw2, mub, sigmab, Wbar, varW, XTX, useVW=True):
    '''
    Evaluate the ELBO as the sum of three pieces: a constant part (c_func),
    a part depending on the variational means (h1_func) and a part depending
    on the variational covariances (h2_func).

    Wbar is a vector holding the diagonal elements of the diagonal matrix W,
        W = diag_matrix(Wbar),  Wbar = diag(W)
    --
    VW is a vector holding the diagonal elements of the diagonal matrix V_w.
    It is computed here as diag(XTX) * diag(varW); with useVW=False it is
    replaced by a vector of zeros.
    --
    sigmab and varW are used as square matrices (their log-determinant and
    trace are taken in h2_func); XTX is the matrix X^T X supplied by the caller.
    '''
    nsample, nvar = X.shape
    if useVW:
        VW = np.diag(XTX) * np.diag(varW)
    else:
        VW = np.zeros(nvar)
    const_part = c_func(nsample, nvar, s2, sb2, sw2)
    mean_part = h1_func(X, y, s2, sb2, sw2, mub, Wbar, VW)
    cov_part = h2_func(nvar, s2, sb2, sw2, XTX, Wbar, sigmab, varW, VW)
    return const_part + mean_part + cov_part
def c_func(n, p, s2, sb2, sw2):
    """Return the part of the ELBO that depends only on n, p and the variances.

    Computes  p - (n/2) log(2 pi s2) - (p/2) log(sb2) - (p/2) log(sw2).
    """
    log_norm_const = 0.5 * n * np.log(2.0 * np.pi * s2)
    log_sb2 = 0.5 * p * np.log(sb2)
    log_sw2 = 0.5 * p * np.log(sw2)
    return p - log_norm_const - log_sb2 - log_sw2
def h1_func(X, y, s2, sb2, sw2, mub, Wbar, VW):
    """Return the ELBO terms that depend on the variational means mub and Wbar.

    Sum of three terms:
      * -(1 / (2 s2)) * ||y - X diag(Wbar) mub||^2
      * -(1/2) * sum_j mub_j^2 * (VW_j / s2 + 1 / sb2)
      * -(1 / (2 sw2)) * sum_j Wbar_j^2
    """
    fitted = np.linalg.multi_dot([X, np.diag(Wbar), mub])
    residual = y - fitted
    fit_term = - (0.5 / s2) * np.sum(np.square(residual))

    mub_weight = (VW / s2) + (1 / sb2)
    mub_term = - 0.5 * np.sum(np.square(mub) * mub_weight)

    wbar_term = - 0.5 * np.sum(np.square(Wbar)) / sw2

    return fit_term + mub_term + wbar_term
def h2_func(p, s2, sb2, sw2, XTX, Wbar, sigmab, sigmaw, VW):
    """Return the ELBO terms that depend on the variational covariances.

    Sum of:
      * (1/2) log|sigmab| + (1/2) log|sigmaw|
      * -(1/2) tr(sigmab) / sb2 - (1/2) tr(sigmaw) / sw2
      * -(1 / (2 s2)) tr[(W XTX W + diag(VW)) sigmab],  W = diag(Wbar)

    sigmab and sigmaw are square matrices. Only the log-magnitude of each
    determinant is used; the sign returned by slogdet is discarded.
    The argument `p` is accepted but not used.
    """
    _, logdet_b = np.linalg.slogdet(sigmab)
    _, logdet_w = np.linalg.slogdet(sigmaw)
    logdet_part = 0.5 * logdet_b + 0.5 * logdet_w

    trace_part = - 0.5 * np.trace(sigmab) / sb2 - 0.5 * np.trace(sigmaw) / sw2

    Wmat = np.diag(Wbar)
    scaled_gram = np.linalg.multi_dot([Wmat.T, XTX, Wmat])
    expected_gram = scaled_gram + np.diag(VW)
    quad_part = - 0.5 * np.dot(expected_gram, sigmab).trace() / s2

    return logdet_part + trace_part + quad_part
def KL_qp_mvn(p, s2, M, S):
    """KL divergence from q = N(M, S) to p = N(0, s2 * I) in `p` dimensions.

    Computes
        (M^T M + tr(S)) / (2 s2) - (1/2) log|S| - p/2 + (p/2) log(s2).

    M is the mean vector and S the (square) covariance matrix of q. Only the
    log-magnitude of det(S) is used; the sign from slogdet is discarded.
    """
    _, logdet_cov = np.linalg.slogdet(S)
    scaled_second_moment = 0.5 * (np.dot(M.T, M) + np.trace(S)) / s2
    dimension_terms = - 0.5 * p + 0.5 * p * np.log(s2)
    return scaled_second_moment - 0.5 * logdet_cov + dimension_terms
def elbo_full(X, y, s2, sb2, sw2, mub, sigmab, Wbar, varW, XTX, useVW=True):
    """Evaluate the ELBO as E_q[log p(y | b, w)] - KL(q(b)) - KL(q(w)).

    The two KL terms are obtained from KL_qp_mvn with prior variances sb2
    (for b) and sw2 (for w). The expected log-likelihood is assembled from
    the Gaussian normalising constant, the squared residual at the
    variational means, and two correction terms involving the variational
    covariances sigmab and varW.

    Arguments have the same meaning as in `elbo`: Wbar is the diagonal of W,
    sigmab and varW are square covariance matrices, XTX is X^T X, and
    useVW=False replaces VW = diag(XTX) * diag(varW) by zeros.
    """
    nsample, nvar = X.shape
    VW = np.diag(XTX) * np.diag(varW) if useVW else np.zeros(nvar)

    # Expected log-likelihood under q.
    Wmat = np.diag(Wbar)
    fitted = np.linalg.multi_dot([X, Wmat, mub])
    scaled_gram = np.linalg.multi_dot([Wmat.T, XTX, Wmat])
    log_norm = - 0.5 * nsample * np.log(2.0 * np.pi * s2)
    resid_term = - (0.5 / s2) * np.sum(np.square(y - fitted))
    mean_var_term = - 0.5 * np.sum(np.square(mub) * (VW / s2))
    cov_term = - 0.5 * np.dot(scaled_gram + np.diag(VW), sigmab).trace() / s2
    expected_loglik = log_norm + resid_term + mean_var_term + cov_term

    # KL divergences of the variational factors from their priors.
    kl_b = KL_qp_mvn(nvar, sb2, mub, sigmab)
    kl_w = KL_qp_mvn(nvar, sw2, Wbar, varW)

    return expected_loglik - kl_b - kl_w
def elbo_simple(X, y, s2, sb2, sw2, mub, sigmab2, muw, sigmaw2):
    """ELBO for the single-coefficient case using univariate normals.

    The coefficient is the product w * b with factorised variational
    posteriors q(b) = N(mub, sigmab2) and q(w) = N(muw, sigmaw2), and
    zero-mean normal priors with variances sb2 and sw2.

    Parameters
    ----------
    X       : 2-D design matrix of shape (n, p); intended for p = 1.
    y       : response vector of length n.
    s2      : residual variance.
    sb2     : prior variance of b.
    sw2     : prior variance of w.
    mub     : variational mean of b (scalar).
    sigmab2 : variational variance of b (scalar).
    muw     : variational mean of w (scalar).
    sigmaw2 : variational variance of w (scalar).

    Returns
    -------
    E_q[log p(y | b, w)] - KL(q(b) || p(b)) - KL(q(w) || p(w)).
    """
    # Take the dimensions from the data itself rather than from module-level
    # globals, so the function is correct for any X it is given.
    n, p = X.shape
    KLqb = KL_qp_normals(0, sb2, mub, sigmab2)
    KLqw = KL_qp_normals(0, sw2, muw, sigmaw2)
    # Second moments of b and w under q.
    Eb2 = mub * mub + sigmab2
    Ew2 = muw * muw + sigmaw2
    t1 = - 0.5 * n * np.log(2 * np.pi * s2)
    # Posterior-mean coefficient, repeated once per column of X.
    bhat = np.repeat(mub * muw, p)
    t2 = - 0.5 * np.sum(np.square(y - np.dot(X, bhat))) / s2
    # Variance of the product w*b, E[w^2 b^2] - (E[w] E[b])^2, scaled by sum(X^2).
    t3 = - 0.5 * np.sum(np.square(X)) * (Eb2 * Ew2 - mub * mub * muw * muw) / s2
    Eqlnpy = t1 + t2 + t3
    return Eqlnpy - KLqb - KLqw


def KL_qp_normals(m1, s1sq, m2, s2sq):
    """KL divergence KL( N(m2, s2sq) || N(m1, s1sq) ) between univariate normals.

    Note the argument order: the first pair (m1, s1sq) is the reference
    distribution and the second pair (m2, s2sq) is the one the expectation
    is taken under.
    """
    val = 0.5 * (np.log(s1sq / s2sq) + (s2sq / s1sq) - 1 + (np.square(m1 - m2) / s1sq))
    return val
# Hyperparameters and variational parameters at which the ELBO is evaluated.
# The residual variance is set to the square of the simulated noise sd.
s2 = sd * sd
sw2 = 0.5 * 0.5
muw = 2.0
# Choose mub so that the product mub * muw equals the true coefficient.
mub = btrue[0] / muw
sigmab2 = 0.2 * 0.2
sigmaw2 = 0.1 * 0.1
# Evaluate the same ELBO three ways. The scalar parameters are wrapped into
# length-1 vectors / 1x1 matrices for the matrix-based implementations.
# These are bare expressions: their values are shown only when run as
# notebook cells.
elbo_simple(X, y, s2, sb2, sw2, mub, sigmab2, muw, sigmaw2)
elbo_full(X, y, s2, sb2, sw2, np.array([mub]), np.eye(p) * sigmab2,
np.array([muw]), np.eye(p) * sigmaw2, np.dot(X.T, X))
elbo(X, y, s2, sb2, sw2, np.array([mub]), np.eye(p) * sigmab2,
np.array([muw]), np.eye(p) * sigmaw2, np.dot(X.T, X))