Multiple regression with product of normals
About
I am trying to understand the problem in the VEB implementation of multiple regression with a product of normals. I checked the analysis with a single predictor ($p=1$) (link) and with two predictors ($p=2$) (link). Here, I check the general case of $p$ predictors whose effect sizes are sampled from a known prior.
import numpy as np
import pandas as pd
from scipy import linalg as sc_linalg
from scipy import special as sc_special
import matplotlib.pyplot as plt
import mpl_stylesheet
import mpl_utils
from matplotlib import cm
from matplotlib import ticker as plticker
from mpl_toolkits.axes_grid1 import make_axes_locatable
# Project-local styling helper (mpl_stylesheet). Presumably it sets global
# matplotlib defaults (font family/size, colour scheme, dpi) for the figures
# below — confirm in mpl_stylesheet.
mpl_stylesheet.banskt_presentation(fontfamily = 'latex-clearsans', fontsize = 18, colors = 'banskt', dpi = 72)
Toy model
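I defined a toy model with 100 predictors and 50 samples. I sampled each predictor from the standard normal distribution $\mathcal{N}(0, 1)$. In short, $y = X(b \circ w) + \epsilon$ with $b_j \sim \mathcal{N}(0, \sigma_b^2)$, $w_j \sim \mathcal{N}(0, \sigma_w^2)$ and $\epsilon \sim \mathcal{N}(0, \sigma^2 I)$; the model is described in detail in the writeup.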
I chose $\sigma = 50.0$, $\sigma_b = 2.0$ and $\sigma_w = 4.0$.
def variance_explained(y, ypred):
    """Return the coefficient of determination R^2 of ``ypred`` against ``y``.

    R^2 = 1 - SS_res / SS_tot, where SS_res is the residual sum of squares
    sum((y - ypred)^2) and SS_tot is sum((y - mean(y))^2).
    """
    residual = y - ypred
    centered = y - np.mean(y)
    return 1 - (np.sum(np.square(residual)) / np.sum(np.square(centered)))
def prod_norm_prior_pdf(z, s1, s2):
    """Density of z = b * w with b ~ N(0, s1^2) and w ~ N(0, s2^2) independent.

    p(z) = K_0(|z| / (s1 s2)) / (pi s1 s2), where K_0 is the modified Bessel
    function of the second kind of order 0. The density diverges at z = 0.
    Works elementwise on array-valued ``z``.
    """
    scaled_abs = np.abs(z) / s1 / s2
    normaliser = np.pi * s1 * s2
    return sc_special.kn(0, scaled_abs) / normaliser
def normal_logdensity_onesample(z, m, s2):
    """Log-density of N(m, s2) at ``z``; ``s2`` is the variance, not the sd."""
    resid = z - m
    log_norm = - 0.5 * np.log(2 * np.pi * s2)
    return log_norm - 0.5 * resid * resid / s2
def normal_logdensity(y, mean, sigma2):
    """Joint log-density of ``y`` with i.i.d. entries y_i ~ N(mean, sigma2).

    ``mean`` may be a scalar or an array broadcastable to ``y``; the sample
    size is taken from the first axis of ``y``.
    """
    nobs = y.shape[0]
    rss = np.sum(np.square(y - mean))
    return - 0.5 * nobs * np.log(2 * np.pi * sigma2) - 0.5 * rss / sigma2
def simulate_data(n, p, s, sb, sw, seed = 200):
    """Simulate y = X (b * w) + e from the product-of-normals regression model.

    Parameters
    ----------
    n, p : number of samples and number of predictors.
    s : noise standard deviation, e ~ N(0, s^2 I).
    sb, sw : prior standard deviations, b_j ~ N(0, sb^2), w_j ~ N(0, sw^2).
    seed : seed applied to NumPy's global random state on every call.

    Returns
    -------
    X : (n, p) design with i.i.d. N(0, 1) entries.
    y : (n,) response.
    b, w : (p,) sampled factors.
    bw : (p,) effect sizes b * w.
    """
    np.random.seed(seed)
    b = np.random.normal(0, sb, p)
    w = np.random.normal(0, sw, p)
    bw = b * w
    # Draw one predictor (column) at a time so the random stream is consumed
    # column by column; reshaping keeps the (n, 0) shape when p == 0.
    columns = np.array([np.random.normal(0, 1, n) for _ in range(p)])
    X = np.ascontiguousarray(columns.reshape(p, n).T)
    err = np.random.normal(0, s, n)
    y = X @ bw + err
    return X, y, b, w, bw
def print_matrix(matrix, fmt):
    """Print a 2-D iterable as tab-separated rows, formatting each cell with ``fmt``."""
    lines = []
    for row in matrix:
        lines.append('\t'.join(fmt.format(cell) for cell in row))
    print('\n'.join(lines))
# Simulation settings: 50 samples and 100 predictors (more predictors than samples).
nsample = 50
npred = 100
# True prior standard deviations of b and w, and the true noise sd.
sigbtrue = 2.0
sigwtrue = 4.0
sigtrue = 50.0
X, y, btrue, wtrue, bwtrue = simulate_data(nsample, npred, sigtrue, sigbtrue, sigwtrue, seed = 300)
# Noise-free signal X(b*w) and the fraction of the variance of y it explains.
ypredtrue = np.dot(X, bwtrue)
r2 = variance_explained(y, ypredtrue)
# 1000 Monte Carlo draws of b*w from the prior, compared below with the
# analytical density. These draws continue from the global RNG state left by
# simulate_data, so statement order here matters.
bwprior = np.multiply(np.random.normal(0, sigbtrue, 1000),
                      np.random.normal(0, sigwtrue, 1000))
print(f"Fraction of variance explained by X: {r2:.2f}")
# Left panel: noise-free signal X(b*w) against observed y.
# Right panel: histogram of prior draws of b*w with the analytical product-normal density overlaid.
fig = plt.figure(figsize = (12, 6))
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)
# 1000 evenly spaced points on [-40, 40]; bw = 0, where the density diverges,
# is not one of the grid points.
bwvals = np.linspace(-40, 40, 1000)
numerical_prior_1d = prod_norm_prior_pdf(bwvals, sigbtrue, sigwtrue)
ax1.scatter(ypredtrue, y)
# mpl_utils is project-local; plot_diag presumably draws the y = x reference line — confirm.
mpl_utils.plot_diag(ax1)
# The label shows the true prior sds (sigma_b, sigma_w), not the variational s_b, s_w.
# NOTE(review): sigbtrue is printed unformatted while sigwtrue uses :.2f.
ax1.text(0.02, 0.94, f"s_b = {sigbtrue}, s_w = {sigwtrue:.2f}", transform = ax1.transAxes)
#ax1.set_title("y = Xbw + e", pad = 20.0)
ax1.set_xlabel("Xbw")
ax1.set_ylabel("y")
ax2.hist(bwprior, bins = 50, density = True, alpha = 0.4, label="Simulation")
ax2.plot(bwvals, numerical_prior_1d, label = "Analytical")
ax2.legend()
ax2.set_title("Prior distribution", pad = 20.0)
ax2.set_xlabel("bw")
ax2.set_ylabel("Density")
plt.tight_layout()
plt.show()
VEB optimization
When I fix the hyperparameters $\sigma$, $\sigma_b$ and $\sigma_w$ to their true values, the variational parameters $m_b$, $s_b$, $m_w$ and $s_w$ are recovered reasonably well. If I fix only $\sigma$, the results are slightly worse, but we still recover the signal. However, there is significant underfitting when I optimize both the hyperparameters and the variational parameters.
I also wanted to check what happens if I use $q(b, w) = q(b \mid w) q(w)$ instead of the mean-field approximation. As a quick check, I hold $b$ and $w$ fixed at their means $m_b$ and $m_w$ while updating the other parameter. There is significant overfitting.
Note: This is not exact, so I have to check and write the equations for $q(b, w) = q(b \mid w) q(w)$.
def get_elbo(X, XTX, XTy, yTy, s2, sb2, sw2, mb, mw, covb, covw):
    """Evidence lower bound for y = X (b * w) + e under q(b) q(w).

    Model: e ~ N(0, s2 I), b ~ N(0, sb2 I), w ~ N(0, sw2 I); variational
    factors q(b) = N(mb, covb), q(w) = N(mw, covw). ``XTX``, ``XTy`` and
    ``yTy`` are the precomputed sufficient statistics X^T X, X^T y and y^T y.

    The bound is split into three pieces: ``cfunc`` (hyperparameter, entropy
    and y^T y terms), ``hfunc`` (expected prior penalty on b) and ``efunc``
    (data-fit terms plus the expected prior penalty on w).
    """
    n_samples, n_pred = X.shape
    constant_part = cfunc(n_samples, n_pred, yTy, s2, sb2, sw2, covb, covw)
    prior_b_part = hfunc(mb, covb, sb2)
    fit_w_part = efunc(XTX, XTy, mb, mw, covb, covw, s2, sw2)
    return constant_part - prior_b_part - fit_w_part
def cfunc(n, p, yTy, s2, sb2, sw2, covb, covw):
    """Hyperparameter, entropy and y^T y terms of the ELBO.

    For y = X (b * w) + e with e ~ N(0, s2 I), b ~ N(0, sb2 I), w ~ N(0, sw2 I)
    and q(b) = N(mb, covb), q(w) = N(mw, covw), this returns

        p - (p/2) log sb2 - (p/2) log sw2 - (n/2) log(2 pi s2)
          + (1/2) log|covb| + (1/2) log|covw| - yTy / (2 s2),

    i.e. the Gaussian entropies of q(b) and q(w), the normalising constants of
    the two priors and of the likelihood, and the y^T y part of the expected
    residual sum of squares. The terms involving mb and mw are computed
    elsewhere (get_elbo subtracts hfunc and efunc from this value).

    Parameters
    ----------
    n, p : number of samples and predictors.
    yTy : y^T y.
    s2, sb2, sw2 : noise variance and prior variances of b and w.
    covb, covw : (p, p) variational covariances (assumed positive definite).
    """
    # slogdet returns (sign, log|det|); the sign is 1 for positive-definite input.
    _, logdet_covb = np.linalg.slogdet(covb)
    _, logdet_covw = np.linalg.slogdet(covw)
    # Each p-dimensional Gaussian entropy contributes p/2 (the "e" in
    # log(2 pi e)); there are two factors, so the constant is p. The log(2 pi)
    # parts of the entropies cancel those of the two priors.
    val = 1.0 * p
    val += - 0.5 * p * np.log(sb2) - 0.5 * p * np.log(sw2)
    val += - 0.5 * n * np.log(2.0 * np.pi * s2)
    val += 0.5 * logdet_covb + 0.5 * logdet_covw
    # y^T y enters via -||y - X(b*w)||^2 / (2 s2), so it lowers the bound.
    val -= 0.5 * yTy / s2
    return val
def hfunc(m, cov, s2):
    """Expected quadratic prior penalty E[x^T x] / (2 s2) for x ~ N(m, cov).

    E[x^T x] = ||m||^2 + tr(cov).
    """
    second_moment = np.sum(m * m) + np.trace(cov)
    return 0.5 * second_moment / s2
def efunc(XTX, XTy, mb, mw, covb, covw, sigma2, sigmaw2):
    """Data-fit terms of the negative ELBO plus the expected prior penalty on w.

    Returns

        E[(b*w)^T X^T X (b*w)] / (2 sigma2) - (mb*mw)^T X^T y / sigma2
          + (||mw||^2 + tr covw) / (2 sigmaw2),

    written as a quadratic form in w: 0.5 E[w^T R w] - mw^T v, where
    R = (X^T X o E[b b^T]) / sigma2 + I / sigmaw2 and v = (mb o X^T y) / sigma2
    (o denotes the elementwise product).
    """
    n_pred = mb.shape[0]
    second_moment_b = np.outer(mb, mb) + covb
    precision_w = np.multiply(XTX, second_moment_b) / sigma2 + np.eye(n_pred) / sigmaw2
    linear_w = mb * XTy / sigma2
    # E[w^T R w] = mw^T R mw + tr(R covw) under q(w) = N(mw, covw).
    quadratic = 0.5 * (mw @ precision_w @ mw)
    cross = mw @ linear_w
    trace_term = 0.5 * np.trace(precision_w @ covw)
    return quadratic - cross + trace_term
def veb_ridge_step(XTX, XTy, sigma2, sigmab2, mw, covw, use_emstep = False):
    """Coordinate-ascent update of q(b) = N(mb, covb) given q(w) = N(mw, covw).

    covb^{-1} = (X^T X o E[w w^T]) / sigma2 + I / sigmab2
    mb        = covb diag(mw) X^T y / sigma2

    where o is the elementwise product and E[w w^T] = mw mw^T + covw. With
    ``use_emstep=True`` the covariance of the other factor is ignored
    (E[w w^T] is replaced by mw mw^T), so ``covw`` is not read.

    The update is symmetric in b and w: veb_iridge calls it once for b and
    once for w with the roles swapped.
    """
    n_pred = XTy.shape[0]
    second_moment_w = np.outer(mw, mw)
    if not use_emstep:
        second_moment_w = second_moment_w + covw
    precision = np.multiply(XTX, second_moment_w) / sigma2 + np.eye(n_pred) / sigmab2
    covb = np.linalg.inv(precision)
    # diag(mw) X^T y is just the elementwise product mw * X^T y.
    mb = covb @ (mw * XTy) / sigma2
    return mb, covb
def get_sigma_updates(X, y, XTX, XTy, yTy, mb, mw, covb, covw, use_emstep = False):
    """Closed-form updates of the noise and prior variances given q(b), q(w).

    sigma2  = (y^T y - 2 (mb o mw)^T X^T y + E||X(b o w)||^2) / n
    sigmab2 = (||mb||^2 + tr covb) / p
    sigmaw2 = (||mw||^2 + tr covw) / p

    with E||X(b o w)||^2 = tr[(X^T X o E[w w^T]) E[b b^T]]. With
    ``use_emstep=True``, E[w w^T] is replaced by mw mw^T in the sigma2 update
    only; E[b b^T] and the sigmaw2 update still use the covariances.
    ``X`` is used only for its shape and ``y`` is not used.

    Returns
    -------
    (sigma2, sigmab2, sigmaw2)
    """
    n, p = X.shape
    moment_w = np.outer(mw, mw) if use_emstep else np.outer(mw, mw) + covw
    moment_b = np.outer(mb, mb) + covb
    weighted_xtx = np.multiply(XTX, moment_w)
    expected_fit = np.dot(weighted_xtx, moment_b).trace()
    cross = np.dot(mb, mw * XTy)
    sigma2 = (yTy - 2 * cross + expected_fit) / n
    sigmab2 = np.sum(np.square(mb) + np.diag(covb)) / p
    sigmaw2 = np.sum(np.square(mw) + np.diag(covw)) / p
    return sigma2, sigmab2, sigmaw2
def veb_iridge(X, y,
               tol = 1e-8, max_iter = 10000,
               init_sigma = 1.0, init_sigmab = 1.0, init_sigmaw = 1.0,
               init_mb = 1.0, init_mw = 1.0,
               init_sb = 1.0, init_sw = 1.0,
               use_convergence = True,
               update_sigma = True,
               update_sigmab = True,
               update_sigmaw = True,
               use_emstep = False,
               debug = True
               ):
    """Fit y = X (b * w) + e by variational EB with mean-field q(b) q(w).

    Each iteration updates q(b), then q(w) (both via veb_ridge_step). It then
    optionally applies the closed-form updates of sigma^2, sigma_b^2 and
    sigma_w^2 selected by the ``update_*`` flags, and records the ELBO.
    Iteration stops when the ELBO gain falls below ``tol`` (a decrease also
    stops it) if ``use_convergence`` is set, or after ``max_iter``
    iterations.

    Initial values: hyperparameters are given as standard deviations; q(b)
    and q(w) start with every mean equal to ``init_mb`` / ``init_mw`` and
    covariance ``init_sb**2 I`` / ``init_sw**2 I``.

    Returns
    -------
    mb, mw, covb, covw, niter, elbo_path, sigma, sigmab, sigmaw
        ``elbo_path`` has length niter + 1 and starts with -inf; the last three
        values are standard deviations.
    """
    _, p = X.shape
    XTX = np.dot(X.T, X)
    XTy = np.dot(X.T, y)
    yTy = np.dot(y.T, y)

    # Hyperparameters: keep both the sd and the variance forms.
    sigma, sigmab, sigmaw = init_sigma, init_sigmab, init_sigmaw
    sigma2 = sigma * sigma
    sigmab2 = sigmab * sigmab
    sigmaw2 = sigmaw * sigmaw

    # Variational parameters.
    covb = np.eye(p) * init_sb * init_sb
    covw = np.eye(p) * init_sw * init_sw
    mb = np.repeat(init_mb, p)
    mw = np.repeat(init_mw, p)

    elbo_path = np.zeros(max_iter + 1)
    elbo_path[0] = -np.inf
    niter = 0
    learn_any = update_sigma or update_sigmab or update_sigmaw
    for itn in range(1, max_iter + 1):
        # Coordinate ascent: q(b) given q(w), then q(w) given the new q(b).
        mb, covb = veb_ridge_step(XTX, XTy, sigma2, sigmab2, mw, covw, use_emstep)
        mw, covw = veb_ridge_step(XTX, XTy, sigma2, sigmaw2, mb, covb, use_emstep)

        if learn_any:
            new_s2, new_sb2, new_sw2 = get_sigma_updates(X, y, XTX, XTy, yTy, mb, mw, covb, covw, use_emstep)
            if update_sigma:
                sigma2 = new_s2
            if update_sigmab:
                sigmab2 = new_sb2
            if update_sigmaw:
                sigmaw2 = new_sw2
        sigma, sigmab, sigmaw = np.sqrt(sigma2), np.sqrt(sigmab2), np.sqrt(sigmaw2)
        if debug:
            print(f"Iteration {itn}")
            print(sigma2, sigmab2, sigmaw2)

        niter = itn
        elbo_path[itn] = get_elbo(X, XTX, XTy, yTy, sigma2, sigmab2, sigmaw2, mb, mw, covb, covw)
        if use_convergence and elbo_path[itn] - elbo_path[itn - 1] < tol:
            break
    return mb, mw, covb, covw, niter, elbo_path[:niter + 1], sigma, sigmab, sigmaw
# Five fits, all initialised at the true hyperparameters. The flags are
# (update_sigma, update_sigmab, update_sigmaw, use_emstep):
#   fix_sall - no hyperparameter updated
#   fix_sbsw - only sigma updated
#   fix_s    - only sigma_b and sigma_w updated
#   fix_none - all hyperparameters updated
#   fix_b/w  - all updated, with the point-estimate ("em step") approximation
vebres = {
    key: veb_iridge(X, y, init_sigma = sigtrue, init_sigmab = sigbtrue, init_sigmaw = sigwtrue,
                    update_sigma = upd_s, update_sigmab = upd_sb, update_sigmaw = upd_sw,
                    debug = False, use_emstep = emstep)
    for key, (upd_s, upd_sb, upd_sw, emstep) in [
        ('fix_sall', (False, False, False, False)),
        ('fix_sbsw', (True,  False, False, False)),
        ('fix_s',    (False, True,  True,  False)),
        ('fix_none', (True,  True,  True,  False)),
        ('fix_b/w',  (True,  True,  True,  True)),
    ]
}
def list_sigmas_vebres(res):
    """Return ``[sigma, sigmab, sigmaw]`` from a veb_iridge result tuple (items 6-8)."""
    return [res[idx] for idx in (6, 7, 8)]
def list_params_vebres(res):
    """Return ``[niter, final ELBO]`` from a veb_iridge result tuple.

    Item 4 is the iteration count; item 5 is the ELBO path, whose last entry
    is reported.
    """
    niter, elbo_path = res[4], res[5]
    return [niter, elbo_path[-1]]
# Summary table: a reference row with the true standard deviations (niter and
# ELBO set to 0 as placeholders), then one row per fit in vebres.
data = [[sigtrue, sigbtrue, sigwtrue] + [0, 0]]
rownames = ['True']
for key, val in vebres.items():
    # Row layout: [s, sb, sw] followed by [niter, final ELBO].
    data.append(list_sigmas_vebres(val) + list_params_vebres(val))
    rownames.append(key)
colnames = ['s', 'sb', 'sw', 'niter', 'ELBO']
# NOTE(review): stale labels from an earlier version; row names now come from the vebres keys.
#rownames = ['True', 'VEB-fix', 'VEB-full', 'VEB-fix-s', 'VEB-fix-sb-sw', 'VEB-high-s']
df = pd.DataFrame.from_records(data, columns = colnames, index = rownames)
#df.style.format("{:.3f}")
# Bare expression: displays the table as notebook cell output; no effect in a plain script.
df
def plot_predictions(ax, X, y, btrue, wtrue, res, key):
    """Scatter observed ``y`` against the VEB prediction X (mb * mw) on ``ax``.

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    X, y : design matrix and observed response.
    btrue, wtrue : unused; kept so existing calls keep working.
    res : veb_iridge result tuple; items 0 and 1 are the posterior means mb, mw.
    key : label written in the upper-left corner of the panel.
    """
    mb, mw = res[0], res[1]
    # Plug-in prediction from the variational means of b and w.
    ypred = np.dot(X, np.multiply(mb, mw))
    ax.scatter(y, ypred)
    # Project-local helper; presumably draws the y = x reference line.
    mpl_utils.plot_diag(ax)
    ax.text(0.1, 0.9, f"{key}", transform = ax.transAxes)
    ax.set_xlabel('y')
    ax.set_ylabel('VEB ypred')
# One panel per fit on a 2 x 3 grid (matplotlib subplot indices are 1-based).
fig = plt.figure(figsize = (18, 12))
for i, (key, res) in enumerate(vebres.items(), start = 1):
    ax = fig.add_subplot(2, 3, i)
    plot_predictions(ax, X, y, btrue, wtrue, res, key)
plt.tight_layout()
plt.show()