About

I am trying to understand a problem in the VEB implementation of multiple regression with a product-of-normals prior. I previously checked the analysis with a single predictor ($p=1$) (link) and with two predictors ($p=2$) (link). Here, I check the case of $p$ predictors whose effect sizes are sampled from a known prior.

import numpy as np
import pandas as pd
from scipy import linalg as sc_linalg
from scipy import special as sc_special
import matplotlib.pyplot as plt

import mpl_stylesheet
import mpl_utils
from matplotlib import cm
from matplotlib import ticker as plticker
from mpl_toolkits.axes_grid1 import make_axes_locatable
# Global plotting style from the project-local mpl_stylesheet module.
mpl_stylesheet.banskt_presentation(
    fontfamily = 'latex-clearsans',
    fontsize = 18,
    colors = 'banskt',
    dpi = 72,
)

Toy model

I defined a toy model with 100 predictors and 50 samples (so $p > n$). Each predictor is sampled from the standard normal distribution $\mathcal{N}(0, 1)$. The model used is described in the writeup.

I chose $\sigma = 50.0$, $\sigma_b = 2.0$ and $\sigma_w = 4.0$.

def variance_explained(y, ypred):
    '''
    Coefficient of determination R^2 = 1 - SS_res / SS_tot of ypred for y.
    '''
    residual = y - ypred
    centred = y - np.mean(y)
    return 1 - np.sum(np.square(residual)) / np.sum(np.square(centred))


def prod_norm_prior_pdf(z, s1, s2):
    '''
    Density of bw = b * w for independent b ~ N(0, s1^2), w ~ N(0, s2^2).

    pdf(z) = K_0(|z| / (s1 s2)) / (pi s1 s2), with K_0 the modified Bessel
    function of the second kind of order 0. The density is infinite at z = 0
    because K_0(0) = inf.
    '''
    scaled = np.abs(z) / s1 / s2
    normaliser = np.pi * s1 * s2
    return sc_special.kn(0, scaled) / normaliser


def normal_logdensity_onesample(z, m, s2):
    '''
    Log density of N(m, s2) at z (elementwise if z / m are arrays).
    '''
    resid = z - m
    log_norm = np.log(2 * np.pi * s2)
    return - 0.5 * log_norm - 0.5 * resid * resid / s2


def normal_logdensity(y, mean, sigma2):
    '''
    Joint log density of the entries of y under independent N(mean, sigma2).

    y is indexed along its first axis; mean may be a scalar or broadcast
    against y.
    '''
    nobs = y.shape[0]
    sq_resid = np.sum(np.square(y - mean))
    log_norm = - 0.5 * nobs * np.log(2 * np.pi * sigma2)
    return log_norm - 0.5 * sq_resid / sigma2


def simulate_data(n, p, s, sb, sw, seed = 200):
    '''
    Simulate y = X (b * w) + e from the product-of-normals regression model.

    Re-seeds the global NumPy RNG with `seed`, then draws in order:
    b ~ N(0, sb^2) (p values), w ~ N(0, sw^2) (p values),
    X with i.i.d. N(0, 1) entries (n x p), and e ~ N(0, s^2) (n values).

    Returns X, y, b, w, bw with bw = b * w.
    '''
    np.random.seed(seed)
    b = np.random.normal(0, sb, p)
    w = np.random.normal(0, sw, p)
    bw = b * w
    # A (p, n) draw fills row k with the same n consecutive values that the
    # column-by-column construction would put in column k; transposing and
    # copying gives the (n, p) C-ordered matrix.
    X = np.random.normal(0, 1, (p, n)).T.copy()
    noise = np.random.normal(0, s, n)
    y = np.dot(X, bw) + noise
    return X, y, b, w, bw

def print_matrix(matrix, fmt):
    '''
    Print a 2-D iterable, one tab-separated row per line, each cell formatted with fmt.format.
    '''
    lines = []
    for row in matrix:
        lines.append('\t'.join(fmt.format(cell) for cell in row))
    print('\n'.join(lines))

# Simulation settings (n = 50 samples, p = 100 predictors) and true hyperparameters.
nsample = 50
npred = 100
sigtrue = 50.0
sigbtrue = 2.0
sigwtrue = 4.0

X, y, btrue, wtrue, bwtrue = simulate_data(nsample, npred, sigtrue, sigbtrue, sigwtrue, seed = 300)
ypredtrue = np.dot(X, bwtrue)
r2 = variance_explained(y, ypredtrue)

# Monte Carlo sample of the product-of-normals prior on bw. These draws
# continue the global NumPy RNG stream seeded inside simulate_data.
bw_prior_b = np.random.normal(0, sigbtrue, 1000)
bw_prior_w = np.random.normal(0, sigwtrue, 1000)
bwprior = bw_prior_b * bw_prior_w

print(f"Fraction of variance explained by X: {r2:.2f}")
Fraction of variance explained by X: 0.77

fig = plt.figure(figsize = (12, 6))
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)

# Analytical product-of-normals density on a grid. With an even number of
# points, bw = 0 (where K_0 diverges) is not a grid point.
bwvals = np.linspace(-40, 40, 1000)
numerical_prior_1d = prod_norm_prior_pdf(bwvals, sigbtrue, sigwtrue)

# Left: observed y against the noiseless true prediction X (b * w).
ax1.scatter(ypredtrue, y)
mpl_utils.plot_diag(ax1)
# Format both prior scales the same way.
ax1.text(0.02, 0.94, f"s_b = {sigbtrue:.2f}, s_w = {sigwtrue:.2f}", transform = ax1.transAxes)
ax1.set_xlabel("Xbw")
ax1.set_ylabel("y")

# Right: Monte Carlo prior sample vs. the analytical density.
ax2.hist(bwprior, bins = 50, density = True, alpha = 0.4, label = "Simulation")
ax2.plot(bwvals, numerical_prior_1d, label = "Analytical")
ax2.legend()
ax2.set_title("Prior distribution", pad = 20.0)
ax2.set_xlabel("bw")
ax2.set_ylabel("Density")

plt.tight_layout()
plt.show()

VEB optimization

When I fix the hyperparameters $\sigma$, $\sigma_b$ and $\sigma_w$ to their true values, the variational parameters $m_b$, $s_b$, $m_w$ and $s_w$ are estimated reasonably well. If I fix only $\sigma$, the results are slightly worse, but we still recover the signal. However, there is significant underfitting when I optimize both the hyperparameters and the variational parameters.

I also wanted to check what happens if I use $q(b, w) = q(b \mid w)\, q(w)$ instead of the mean-field approximation. As a quick check, I hold $b$ and $w$ fixed at their means $m_b$ and $m_w$ while updating the other parameter. This leads to significant overfitting.

Note: this is not exact; I still have to derive and check the equations for $q(b, w) = q(b \mid w)\, q(w)$.

def get_elbo(X, XTX, XTy, yTy, s2, sb2, sw2, mb, mw, covb, covw):
    '''
    Evaluate the ELBO as cfunc(...) - hfunc(...) - efunc(...).

    cfunc collects the terms free of the variational means, hfunc the
    quadratic b-prior term and efunc the expected data fit plus the
    quadratic w-prior term. Only X.shape is read from X.
    '''
    nsamp, npred = X.shape
    const_terms = cfunc(nsamp, npred, yTy, s2, sb2, sw2, covb, covw)
    b_prior_terms = hfunc(mb, covb, sb2)
    fit_terms = efunc(XTX, XTy, mb, mw, covb, covw, s2, sw2)
    return const_terms - b_prior_terms - fit_terms
 
def cfunc(n, p, yTy, s2, sb2, sw2, covb, covw):
    '''
    ELBO terms that do not involve the variational means.

    Model: y ~ N(X (b * w), s2 I), b ~ N(0, sb2 I), w ~ N(0, sw2 I), with
    q(b) = N(mb, covb) and q(w) = N(mw, covw). The returned scalar is

        p                                   (+p/2 from each Gaussian entropy, b and w)
        - p/2 log sb2 - p/2 log sw2         (prior normalisers; the 2*pi factors
                                             cancel against the entropies)
        - n/2 log(2 pi s2)                  (likelihood normaliser)
        + 1/2 log|covb| + 1/2 log|covw|     (entropy log-determinants)
        - y'y / (2 s2)                      (the y'y part of -E||y - X(b*w)||^2 / (2 s2))

    The mean-dependent terms are hfunc and efunc, which get_elbo subtracts.
    '''
    # slogdet returns (sign, log|det|); the covariances are expected to be
    # positive definite, so the sign is ignored.
    _, logdet_covb = np.linalg.slogdet(covb)
    _, logdet_covw = np.linalg.slogdet(covw)
    val  =   float(p)
    val += - 0.5 * p * np.log(sb2) - 0.5 * p * np.log(sw2)
    val += - 0.5 * n * np.log(2.0 * np.pi * s2)
    val +=   0.5 * logdet_covb + 0.5 * logdet_covw
    # Negative: this is part of the expected squared residual. A '+' here
    # makes the bound drop as s2 grows and breaks the convergence test.
    val += - 0.5 * yTy / s2
    return val

def hfunc(m, cov, s2):
    '''
    Quadratic prior term E_q[||x||^2] / (2 s2) for q = N(m, cov).

    Equals (||m||^2 + tr(cov)) / (2 s2).
    '''
    expected_sq_norm = np.trace(cov) + np.sum(np.square(m))
    return 0.5 * expected_sq_norm / s2

def efunc(XTX, XTy, mb, mw, covb, covw, sigma2, sigmaw2):
    '''
    Expected data-fit and w-prior terms of the negative ELBO (without y'y).

    With Rb = (XTX o E[b b']) / sigma2 + I / sigmaw2 (o = elementwise product,
    E[b b'] = mb mb' + covb) and vb = (mb o XTy) / sigma2, returns

        0.5 mw' Rb mw - mw' vb + 0.5 tr(Rb covw).
    '''
    npred = mb.shape[0]
    Ebb = np.outer(mb, mb) + covb
    Rb = np.multiply(XTX, Ebb) / sigma2 + np.eye(npred) / sigmaw2
    vb = mb * XTy / sigma2
    quad_term = 0.5 * np.linalg.multi_dot([mw, Rb, mw])
    linear_term = np.dot(mw, vb)
    cov_term = 0.5 * np.dot(Rb, covw).trace()
    return quad_term - linear_term + cov_term


def veb_ridge_step(XTX, XTy, sigma2, sigmab2, mw, covw, use_emstep = False):
    '''
    Gaussian update of q(b) = N(mb, covb) given q(w) = N(mw, covw).

    Precision: (XTX o E[w w']) / sigma2 + I / sigmab2, where o is the
    elementwise product and E[w w'] = mw mw' + covw. When use_emstep is
    True, covw is ignored and E[w w'] is replaced by mw mw'.
    Mean: covb diag(mw) XTy / sigma2.

    Also used for the symmetric q(w) update by swapping the roles of b and w.
    Returns (mb, covb).
    '''
    npred = XTy.shape[0]
    Eww = np.einsum('i,j->ij', mw, mw)
    if not use_emstep:
        Eww = Eww + covw
    precision = np.multiply(XTX, Eww) / sigma2 + np.eye(npred) / sigmab2
    covb = np.linalg.inv(precision)
    mb = np.linalg.multi_dot([covb, np.diag(mw), XTy]) / sigma2
    return mb, covb

def get_sigma_updates(X, y, XTX, XTy, yTy, mb, mw, covb, covw, use_emstep = False):
    '''
    Closed-form hyperparameter updates given q(b) and q(w).

    sigma2  = (y'y - 2 (mb o mw)' X'y + tr[(XTX o E[w w']) E[b b']]) / n
    sigmab2 = (||mb||^2 + tr covb) / p
    sigmaw2 = (||mw||^2 + tr covw) / p

    With use_emstep, covw is dropped from E[w w'] in the sigma2 update only.
    X is used only for its shape; y is not read.
    '''
    nsamp, npred = X.shape
    Eww = np.einsum('i,j->ij', mw, mw)
    if not use_emstep:
        Eww = Eww + covw
    Ebb = np.einsum('i,j->ij', mb, mb) + covb
    cross_term = np.dot(mb, mw * XTy)
    trace_term = np.dot(np.multiply(XTX, Eww), Ebb).trace()

    sigma2 = (yTy - 2 * cross_term + trace_term) / nsamp
    sigmab2 = np.sum(np.square(mb) + np.diag(covb)) / npred
    sigmaw2 = np.sum(np.square(mw) + np.diag(covw)) / npred
    return sigma2, sigmab2, sigmaw2

def veb_iridge(X, y,
               tol = 1e-8, max_iter = 10000,
               init_sigma = 1.0, init_sigmab = 1.0, init_sigmaw = 1.0,
               init_mb = 1.0, init_mw = 1.0,
               init_sb = 1.0, init_sw = 1.0,
               use_convergence = True,
               update_sigma = True,
               update_sigmab = True,
               update_sigmaw = True,
               use_emstep = False,
               debug = True
              ):
    '''
    Mean-field VEB for y = X (b * w) + e with Gaussian q(b) and q(w).

    Each sweep updates q(b) given q(w), then q(w) given the new q(b)
    (veb_ridge_step). It then optionally re-estimates the hyperparameters
    (get_sigma_updates) and records the ELBO (get_elbo).

    Parameters
    ----------
    X : (n, p) predictor matrix.
    y : response vector of length n.
    tol : stop once one sweep improves the ELBO by less than tol.
    max_iter : maximum number of sweeps.
    init_sigma, init_sigmab, init_sigmaw : initial standard deviations
        (not variances) of the noise and of the b / w priors.
    init_mb, init_mw : value used for every entry of the initial means.
    init_sb, init_sw : initial covariances are init_sb**2 I and init_sw**2 I.
    use_convergence : if False, always run max_iter sweeps.
    update_sigma, update_sigmab, update_sigmaw : which hyperparameters are
        re-estimated; the others stay at their initial values.
    use_emstep : forwarded to veb_ridge_step / get_sigma_updates, which then
        ignore the partner's covariance. Those updates are not the exact
        coordinate-ascent steps for this ELBO, so the ELBO is not guaranteed
        to increase.
    debug : print the hyperparameter variances after every sweep.

    Returns
    -------
    (mb, mw, covb, covw, niter, elbo_path, sigma, sigmab, sigmaw)
        elbo_path has niter + 1 entries: elbo_path[0] = -inf is a placeholder
        and elbo_path[k] is the ELBO after sweep k. sigma, sigmab, sigmaw are
        standard deviations.

    Warns
    -----
    RuntimeWarning
        When the ELBO decreases by more than tol between sweeps. The loop
        still stops in that case, but the decrease is reported so that it is
        not mistaken for convergence.
    '''
    import warnings

    _, p = X.shape
    XTX = np.dot(X.T, X)
    XTy = np.dot(X.T, y)
    yTy = np.dot(y.T, y)
    elbo_path = np.zeros(max_iter + 1)

    # Hyperparameters: standard deviations are returned, variances are used
    # in the updates.
    sigma   = init_sigma
    sigmab  = init_sigmab
    sigmaw  = init_sigmaw
    sigma2  = sigma  * sigma
    sigmab2 = sigmab * sigmab
    sigmaw2 = sigmaw * sigmaw

    # Variational parameters of q(b) = N(mb, covb) and q(w) = N(mw, covw).
    covb = np.eye(p) * init_sb * init_sb
    covw = np.eye(p) * init_sw * init_sw
    mb   = np.repeat(init_mb, p)
    mw   = np.repeat(init_mw, p)

    niter = 0
    elbo_path[0] = -np.inf
    for itn in range(1, max_iter + 1):
        # Coordinate updates: q(b) given q(w), then q(w) given the new q(b).
        mb, covb = veb_ridge_step(XTX, XTy, sigma2, sigmab2, mw, covw, use_emstep)
        mw, covw = veb_ridge_step(XTX, XTy, sigma2, sigmaw2, mb, covb, use_emstep)

        if update_sigma or update_sigmab or update_sigmaw:
            _sigma2, _sigmab2, _sigmaw2 = get_sigma_updates(X, y, XTX, XTy, yTy, mb, mw, covb, covw, use_emstep)
            if update_sigma:  sigma2  = _sigma2
            if update_sigmab: sigmab2 = _sigmab2
            if update_sigmaw: sigmaw2 = _sigmaw2

        sigma  = np.sqrt(sigma2)
        sigmab = np.sqrt(sigmab2)
        sigmaw = np.sqrt(sigmaw2)

        if debug:
            print(f"Iteration {itn}")
            print(sigma2, sigmab2, sigmaw2)

        # Convergence
        niter += 1
        elbo_path[itn] = get_elbo(X, XTX, XTy, yTy, sigma2, sigmab2, sigmaw2, mb, mw, covb, covw)
        if use_convergence:
            delta = elbo_path[itn] - elbo_path[itn - 1]
            if delta < -tol:
                # A real decrease signals an inconsistency between the update
                # equations and the ELBO, not convergence.
                warnings.warn(
                    f"ELBO decreased by {-delta:.3e} at iteration {itn}; stopping. "
                    "Check the ELBO terms and the update equations.",
                    RuntimeWarning,
                )
            if delta < tol:
                break

    return mb, mw, covb, covw, niter, elbo_path[:niter + 1], sigma, sigmab, sigmaw

# Every run starts the hyperparameters at their true values. The runs differ
# in which of sigma, sigma_b, sigma_w are re-estimated and in use_emstep.
_veb_common = dict(init_sigma = sigtrue, init_sigmab = sigbtrue, init_sigmaw = sigwtrue, debug = False)
_veb_settings = {
    'fix_sall': dict(update_sigma = False, update_sigmab = False, update_sigmaw = False),
    'fix_sbsw': dict(update_sigma = True,  update_sigmab = False, update_sigmaw = False),
    'fix_s':    dict(update_sigma = False, update_sigmab = True,  update_sigmaw = True),
    'fix_none': dict(update_sigma = True,  update_sigmab = True,  update_sigmaw = True),
    'fix_b/w':  dict(update_sigma = True,  update_sigmab = True,  update_sigmaw = True,
                     use_emstep = True),
}

vebres = dict()
for _run_name, _run_opts in _veb_settings.items():
    vebres[_run_name] = veb_iridge(X, y, **_veb_common, **_run_opts)

def list_sigmas_vebres(res):
    '''
    Return [sigma, sigmab, sigmaw], taken from positions 6, 7 and 8 of a veb_iridge result tuple.
    '''
    return [res[k] for k in (6, 7, 8)]

def list_params_vebres(res):
    '''
    Return [niter, final ELBO] from a veb_iridge result tuple.

    niter is position 4 and the ELBO path is position 5.
    '''
    niter, elbo_trace = res[4], res[5]
    return [niter, elbo_trace[-1]]

# The 'True' row has no fit, so niter and ELBO are missing (NaN) rather than
# a misleading 0.
data = [[sigtrue, sigbtrue, sigwtrue] + [np.nan, np.nan]]
rownames = ['True']
for key, val in vebres.items():
    data.append(list_sigmas_vebres(val) + list_params_vebres(val))
    rownames.append(key)

colnames = ['s', 'sb', 'sw', 'niter', 'ELBO']
df = pd.DataFrame.from_records(data, columns = colnames, index = rownames)
# Nullable integer dtype keeps the iteration counts integral next to the missing value.
df['niter'] = df['niter'].astype('Int64')
df
s sb sw niter ELBO
True 50.000000 2.000000 4.000000 0 0.000000e+00
fix_sall 50.000000 2.000000 4.000000 140 -2.109103e+02
fix_sbsw 98.965709 2.000000 4.000000 2 -3.190882e+02
fix_s 50.000000 1.671161 2.725174 92 -2.047158e+02
fix_none 95.628828 1.934922 2.915587 2 -3.120212e+02
fix_b/w 0.097264 2.345863 4.170464 1128 3.304583e+07

def plot_predictions(ax, X, y, btrue, wtrue, res, key):
    '''
    Scatter observed y against the VEB prediction X (mb * mw) on ax.

    res is a veb_iridge result tuple (mb = res[0], mw = res[1]). key is
    written near the top-left of the axes. btrue and wtrue are accepted for
    call-site compatibility but are not used.
    '''
    mb, mw = res[0], res[1]
    ypred = np.dot(X, np.multiply(mb, mw))
    ax.scatter(y, ypred)
    mpl_utils.plot_diag(ax)
    ax.text(0.1, 0.9, f"{key}", transform = ax.transAxes)
    ax.set_xlabel('y')
    ax.set_ylabel('VEB ypred')
    

fig = plt.figure(figsize = (18, 12))

# One panel per VEB run on a 2 x 3 grid (subplot indices start at 1).
for i, (key, res) in enumerate(vebres.items(), start = 1):
    ax = fig.add_subplot(2, 3, i)
    plot_predictions(ax, X, y, btrue, wtrue, res, key)

plt.tight_layout()
plt.show()