Demonstration of EM-VAMP with changepoint simulation
%load_ext autoreload
%autoreload 2
import numpy as np
import os
import sys
import vampyre
import matplotlib.pyplot as plt
from pymir import mpl_stylesheet
from pymir import mpl_utils
# Select the PyMir "banskt_presentation" plotting style (presumably configures
# matplotlib defaults for every figure drawn afterwards -- confirm in PyMir).
mpl_stylesheet.banskt_presentation(dpi = 72, splinecolor = "black")
# NOTE(review): absolute, machine-specific path to the DSC sources; it has to be
# edited before this script can run on another machine.
srcdir = "/home/saikat/Documents/work/ebmr/simulation/eb-linreg-dsc/dsc"
# Put <srcdir>/functions on the import path. This must run before the `fit` and
# `simulate` imports that follow (presumably both modules live there).
sys.path.append(os.path.join(srcdir, "functions"))
from fit import fit_em_vamp
import simulate
# --- Simulation settings -----------------------------------------------------
# `n` is the number of samples (the plots below index rows of X / entries of y
# with np.arange(n)). The remaining names are passed straight through to
# simulate.changepoint_predictors; presumably p = number of predictors,
# s = number of changepoints, snr = signal-to-noise ratio, and bfix = the fixed
# effect size used when signal == "fixed" -- TODO confirm against `simulate`.
n, p = 100, 200
s, snr = 1, 10
signal, bfix = "fixed", 8

# Draw one data set: training pair (X, y), held-out pair (Xtest, ytest), the
# coefficient vector `beta`, and `se` (used further down to scale the RMSE).
X, y, Xtest, ytest, beta, se = simulate.changepoint_predictors(
    n, p, s, snr, signal=signal, bfix=bfix
)
# Show the simulated data: the noiseless signal X.beta as a line and the
# observed responses y (legend label "Xb + e") as open circles.
fig, ax1 = plt.subplots()
ax1.plot(np.arange(n), np.dot(X, beta), label="Xb")
ax1.scatter(
    np.arange(n), y,
    edgecolor='black', facecolor='white', label="Xb + e",
)
ax1.set(xlabel="Sample index", ylabel="y")
ax1.legend()
plt.show()
def get_rmse(X, y, bhist):
    """Return the RMSE of each coefficient estimate in `bhist` on the data (X, y).

    Parameters
    ----------
    X : array of shape (n, p)
        Predictor matrix on which the predictions are evaluated.
    y : array with n entries
        Responses matching the rows of `X`; any shape that reshapes to (n, 1).
    bhist : sequence of arrays
        One coefficient estimate per iteration. Element 0 of each estimate is
        the intercept and the remaining p elements are the slopes.

    Returns
    -------
    numpy.ndarray of shape (len(bhist),)
        Root mean squared error of y - (X @ slopes + intercept), one entry per
        element of `bhist`.
    """
    n = X.shape[0]
    # Column layout so the residual below is always (n, 1).
    ycol = np.reshape(y, (n, 1))
    rmse = np.zeros(len(bhist))
    for it, bhati in enumerate(bhist):
        # Force a column vector: with a 1-D estimate the prediction would be
        # shape (n,), and subtracting it from the (n, 1) response would
        # silently broadcast to an (n, n) matrix.
        coefs = np.asarray(bhati).reshape(-1, 1)
        ypred = np.dot(X, coefs[1:]) + coefs[0]
        rmse[it] = np.sqrt(np.mean((ycol - ypred) ** 2))
    return rmse


# Fit EM-VAMP on the training data ...
emvamp_bhist, emvamp_mu, emvamp_beta = fit_em_vamp(X, y)
# ... and score every stored iterate on the held-out data. The held-out pair is
# passed explicitly because get_rmse evaluates exactly the data it is given.
emvamp_rmse = get_rmse(Xtest, ytest, emvamp_bhist)
# Convergence plot: log10 of (RMSE / se) against log10 of (1 + iteration index).
fig, ax1 = plt.subplots(figsize=(6, 6))

# Shift the index by one so the first iterate sits at log10(1) = 0.
xvals = np.log10(1 + np.arange(emvamp_rmse.shape[0]))
ax1.plot(xvals, np.log10(emvamp_rmse / se), 's-', alpha=0.5)
#ax1.legend(['EM-VAMP (ash)', 'EM-VAMP'])

ax1.set_xlabel(r"Number of iterations")
ax1.set_ylabel(r"Prediction Error (RMSE / $\sigma$)")
ax1.tick_params(labelcolor="#333333")

# PyMir tick helpers (presumably relabel the log10-transformed axes in original
# units -- confirm in PyMir). Remove these two calls if PyMir is not available.
for set_ticks in (mpl_utils.set_yticks, mpl_utils.set_xticks):
    set_ticks(ax1, kmin=4, kmax=6, scale='log10', spacing='log10')

plt.show()
# Overlay the final EM-VAMP fit (line) on the training responses (open circles).
# NOTE(review): the fitted values use only `emvamp_beta`; `emvamp_mu` returned
# by fit_em_vamp is not added here -- confirm whether an intercept is intended.
vamp_ypred = np.dot(X, emvamp_beta.ravel())

fig, ax1 = plt.subplots()
ax1.scatter(np.arange(n), y, edgecolor='black', facecolor='white')
ax1.plot(np.arange(n), vamp_ypred)
ax1.set(xlabel="Sample index", ylabel="y")
plt.show()