About

Peter suggested checking whether the ELBO calculation reduces to the correct value when $p=1$. Here, I calculate the ELBO for fixed values of the hyperparameters and variational parameters. I compare the numerical values obtained from the ELBO code written earlier with those from a simpler version of the ELBO that uses univariate normals. They agree, at least numerically, for the case $p=1$.

import numpy as np
import pandas as pd
from scipy import linalg as sc_linalg
import matplotlib.pyplot as plt

import sys
sys.path.append("../../ebmrPy/")
from inference.ebmr import EBMR
from inference import f_elbo
from inference import f_sigma
from inference import penalized_em
from utils import log_density

sys.path.append("../../utils/")
import mpl_stylesheet
mpl_stylesheet.banskt_presentation(fontfamily = 'latex-clearsans', fontsize = 18, colors = 'banskt', dpi = 72)

def standardize(X):
    '''
    Center every column of X and rescale it to unit Euclidean (L2) norm.

    The scaling uses the column's L2 norm after centering, not its standard
    deviation, so each returned column has mean 0 and sum of squares 1.
    A constant column has zero norm after centering and yields NaN (0 / 0).
    '''
    centered = X - np.mean(X, axis = 0)
    col_norms = np.sqrt(np.sum(np.square(centered), axis = 0))
    return centered / col_norms

def lasso_data(nsample, nvar, neff, errsigma, sb2 = 100, seed=100):
    '''
    Simulate a sparse linear-regression data set.

    Side effect: reseeds NumPy's global random state with `seed`.

    nsample  : number of rows of X
    nvar     : number of columns of X
    neff     : number of non-zero true coefficients (chosen without replacement)
    errsigma : standard deviation of the Gaussian noise added to X @ btrue
    sb2      : variance of the non-zero true coefficients
    seed     : seed passed to np.random.seed

    Returns (X, y, btrue): X is passed through `standardize`, y is centered,
    and btrue has `neff` non-zero entries.
    '''
    np.random.seed(seed)
    Xraw = np.random.normal(0, 1, nsample * nvar).reshape(nsample, nvar)
    X = standardize(Xraw)
    btrue = np.zeros(nvar)
    causal_idx = np.random.choice(nvar, neff, replace = False)
    btrue[causal_idx] = np.random.normal(0, np.sqrt(sb2), neff)
    signal = np.dot(X, btrue)
    noise = np.random.normal(0, errsigma, nsample)
    y = signal + noise
    return X, y - np.mean(y), btrue

def lims_xy(ax):
    '''
    Return [lo, hi], where lo / hi are the smallest / largest values among
    the current x-limits and y-limits of `ax`.
    '''
    bounds = np.concatenate([ax.get_xlim(), ax.get_ylim()])
    return [bounds.min(), bounds.max()]

def plot_diag(ax):
    '''Draw a dotted gray y = x line on `ax`, spanning the range returned by lims_xy(ax).'''
    lo, hi = lims_xy(ax)
    ax.plot([lo, hi], [lo, hi], ls='dotted', color='gray')

# Simulate a single-predictor (p = 1) data set and plot y against the true signal Xb.
n = 100
p = 1
peff = 1
sb2 = 100.0
sd = 2.0
X, y, btrue = lasso_data(nsample = n, nvar = p, neff = peff, errsigma = sd, sb2 = sb2)

fig = plt.figure()
ax1 = fig.add_subplot(1, 1, 1)
ax1.scatter(X @ btrue, y)
plot_diag(ax1)
ax1.set(xlabel = "Xb", ylabel = "y")

plt.show()

def elbo(X, y, s2, sb2, sw2, mub, sigmab, Wbar, varW, XTX, useVW=True):
    '''
    ELBO assembled from three pieces:
      c_func  -> terms that do not depend on the variational parameters
      h1_func -> terms in the variational means (mub, Wbar)
      h2_func -> terms in the variational covariances (sigmab, varW)

    Wbar is a vector holding the diagonal elements of the diagonal matrix W,
    i.e. W = diag_matrix(Wbar) and Wbar = diag(W).
    --
    VW is a vector holding the diagonal elements of the diagonal matrix V_w:
    VW = diag(XTX) * diag(varW), so only the diagonal of varW enters VW.
    With useVW=False, VW is replaced by zeros.
    '''
    n, p = X.shape
    if useVW:
        VW = np.diag(XTX) * np.diag(varW)
    else:
        VW = np.zeros(p)
    constant_part = c_func(n, p, s2, sb2, sw2)
    mean_part = h1_func(X, y, s2, sb2, sw2, mub, Wbar, VW)
    cov_part = h2_func(p, s2, sb2, sw2, XTX, Wbar, sigmab, varW, VW)
    return constant_part + mean_part + cov_part


def c_func(n, p, s2, sb2, sw2):
    '''
    Constant part of the ELBO (independent of the variational parameters).

    It is the sum of:
      p                          : the "+0.5 p" from each of the two Gaussian KL terms
      -0.5 n log(2 pi s2)        : Gaussian likelihood normalizer
      -0.5 p log(sb2), -0.5 p log(sw2) : prior-variance parts of the two KL terms
    '''
    lik_norm = - 0.5 * n * np.log(2.0 * np.pi * s2)
    prior_b_norm = - 0.5 * p * np.log(sb2)
    prior_w_norm = - 0.5 * p * np.log(sw2)
    return p + lik_norm + prior_b_norm + prior_w_norm


def h1_func(X, y, s2, sb2, sw2, mub, Wbar, VW):
    '''
    ELBO terms that depend on the variational means mub and Wbar.

    Returns the sum of:
      -||y - X diag(Wbar) mub||^2 / (2 s2)     : squared residual at the means
      -0.5 * sum(mub^2 * (VW / s2 + 1 / sb2))  : VW-weighted term plus prior term for b
      -0.5 * sum(Wbar^2) / sw2                 : prior quadratic term for w
    '''
    fitted = np.linalg.multi_dot([X, np.diag(Wbar), mub])
    residual = y - fitted
    fit_term = - (0.5 / s2) * np.sum(np.square(residual))
    b_weights = (VW / s2) + (1 / sb2)
    b_term = - 0.5 * np.sum(np.square(mub) * b_weights)
    w_term = - 0.5 * np.sum(np.square(Wbar)) / sw2
    return fit_term + b_term + w_term


def h2_func(p, s2, sb2, sw2, XTX, Wbar, sigmab, sigmaw, VW):
    '''
    ELBO terms that depend on the variational covariances sigmab and sigmaw.

    Returns the sum of:
      0.5 logdet(sigmab) + 0.5 logdet(sigmaw)                  : log-determinant terms
      -0.5 tr(sigmab) / sb2 - 0.5 tr(sigmaw) / sw2             : prior trace terms
      -0.5 tr((W' XTX W + diag(VW)) sigmab) / s2, W = diag(Wbar) : coupling term

    `p` is not used; it is kept so the signature matches the callers.
    The sign returned by slogdet is ignored, so sigmab and sigmaw are
    presumably expected to be positive definite.
    '''
    _, logdet_b = np.linalg.slogdet(sigmab)
    _, logdet_w = np.linalg.slogdet(sigmaw)
    Wmat = np.diag(Wbar)
    scaled_gram = np.linalg.multi_dot([Wmat.T, XTX, Wmat])
    logdet_term = 0.5 * logdet_b + 0.5 * logdet_w
    prior_trace = - 0.5 * np.trace(sigmab) / sb2 - 0.5 * np.trace(sigmaw) / sw2
    coupling = - 0.5 * np.trace(np.dot(scaled_gram + np.diag(VW), sigmab)) / s2
    return (logdet_term + prior_trace) + coupling

def KL_qp_mvn(p, s2, M, S):
    '''
    KL divergence KL( N(M, S) || N(0, s2 * I_p) ).

    p  : dimension of the normals
    s2 : variance of the isotropic, zero-mean reference normal
    M  : mean vector of q
    S  : covariance matrix of q (the slogdet sign is ignored, so S is
         presumably positive definite)
    '''
    _, logdet = np.linalg.slogdet(S)
    quad = 0.5 * (np.dot(M.T, M) + np.trace(S)) / s2
    dim_term = - 0.5 * p + 0.5 * p * np.log(s2)
    return (quad - 0.5 * logdet) + dim_term
    

def elbo_full(X, y, s2, sb2, sw2, mub, sigmab, Wbar, varW, XTX, useVW=True):
    '''
    ELBO written as E_q[ln p(y | b, w)] - KL(q(b) || p(b)) - KL(q(w) || p(w)).

    q(b) = N(mub, sigmab) with prior N(0, sb2 I); q(w) = N(Wbar, varW) with
    prior N(0, sw2 I); W = diag(Wbar). VW = diag(XTX) * diag(varW), or zeros
    when useVW is False. Only the diagonal of varW enters the expected
    log-likelihood; the full varW enters KL(q(w) || p(w)).
    '''
    n, p = X.shape
    VW = np.diag(XTX) * np.diag(varW) if useVW else np.zeros(p)
    kl_b = KL_qp_mvn(p, sb2, mub, sigmab)
    kl_w = KL_qp_mvn(p, sw2, Wbar, varW)
    W = np.diag(Wbar)
    fitted = np.linalg.multi_dot([X, W, mub])
    scaled_gram = np.linalg.multi_dot([W.T, XTX, W])
    # Expected Gaussian log-likelihood, split into its four contributions.
    log_norm = - 0.5 * n * np.log(2.0 * np.pi * s2)
    resid_term = - (0.5 / s2) * np.sum(np.square(y - fitted))
    wvar_term = - 0.5 * np.sum(np.square(mub) * (VW / s2))
    cov_term = - 0.5 * np.dot(scaled_gram + np.diag(VW), sigmab).trace() / s2
    expected_loglik = log_norm + resid_term + wvar_term + cov_term
    return expected_loglik - kl_b - kl_w
    


def elbo_simple(X, y, s2, sb2, sw2, mub, sigmab2, muw, sigmaw2):
    '''
    ELBO for the scalar case, written with univariate normals.

    Model: y ~ N(X * (w b), s2 I), b ~ N(0, sb2), w ~ N(0, sw2).
    Variational factors: q(b) = N(mub, sigmab2), q(w) = N(muw, sigmaw2).
    Intended as a check against `elbo` / `elbo_full` for p = 1.

    n and p are taken from X.shape; X must be 2-D (n x p).
    NOTE(review): for p > 1 every coefficient in the mean term shares the
    same product mub * muw, so the result is only meaningful for p = 1.
    '''
    n, p = X.shape
    KLqb = KL_qp_normals(0, sb2, mub, sigmab2)
    KLqw = KL_qp_normals(0, sw2, muw, sigmaw2)
    # Second moments under q.
    Eb2 = mub * mub + sigmab2
    Ew2 = muw * muw + sigmaw2
    t1 = - 0.5 * n * np.log(2 * np.pi * s2)
    bhat = np.repeat(mub * muw, p)
    t2 = - 0.5 * np.sum(np.square(y - np.dot(X, bhat))) / s2
    # Var_q(w b) = E[w^2] E[b^2] - (muw mub)^2, weighted by ||x||^2.
    t3 = - 0.5 * np.sum(np.square(X)) * (Eb2 * Ew2 - mub * mub * muw * muw) / s2
    Eqlnpy = t1 + t2 + t3
    return Eqlnpy - KLqb - KLqw


def KL_qp_normals(m1, s1sq, m2, s2sq):
    '''
    KL divergence KL( N(m2, s2sq) || N(m1, s1sq) ) between univariate normals.

    (m1, s1sq) : mean and variance of the reference distribution p
    (m2, s2sq) : mean and variance of the approximating distribution q
    `s2sq` is only this function's parameter name; it is unrelated to the
    residual variance s2 used elsewhere.
    '''
    val = 0.5 * (np.log(s1sq / s2sq) + (s2sq / s1sq) - 1 + (np.square(m1 - m2) / s1sq))
    return val
# Fixed hyperparameters and variational parameters for the p = 1 comparison.
# s2 is the square of sd, the noise standard deviation passed to lasso_data above.
s2 = sd * sd
sw2 = 0.5 * 0.5
muw = 2.0
# Choose mub so that the product mub * muw equals the simulated coefficient btrue[0].
mub = btrue[0] / muw
sigmab2 = 0.2 * 0.2
sigmaw2 = 0.1 * 0.1
# 1) Univariate-normal ELBO with scalar arguments.
elbo_simple(X, y, s2, sb2, sw2, mub, sigmab2, muw, sigmaw2)
# NOTE(review): presumably the recorded notebook output of the call above;
# as a plain script line this bare literal is a no-op.
-229.70090826648755
# 2) Matrix-form ELBO: scalars wrapped as length-1 arrays, variances as p x p matrices.
elbo_full(X, y, s2, sb2, sw2, np.array([mub]), np.eye(p) * sigmab2, 
     np.array([muw]), np.eye(p) * sigmaw2, np.dot(X.T, X))
# NOTE(review): presumably the recorded notebook output of the call above.
-229.70090826648755
# 3) Decomposed ELBO (c_func + h1_func + h2_func) with the same arguments.
elbo(X, y, s2, sb2, sw2, np.array([mub]), np.eye(p) * sigmab2, 
     np.array([muw]), np.eye(p) * sigmaw2, np.dot(X.T, X))
# NOTE(review): presumably the recorded notebook output of the call above.
-229.70090826648755