About

Here, I look at the performance of EM-VAMP on a changepoint simulation.

%load_ext autoreload
%autoreload 2

import numpy as np
import os
import sys
# NOTE(review): vampyre is not referenced directly in this notebook;
# presumably kept for fit_em_vamp's backend — confirm before removing.
import vampyre
import matplotlib.pyplot as plt
from pymir import mpl_stylesheet
from pymir import mpl_utils
# PyMir helper that (presumably) sets a global matplotlib presentation style
# for every figure drawn below.
mpl_stylesheet.banskt_presentation(dpi = 72, splinecolor = "black")

# Make the DSC helper modules (`fit`, `simulate`) importable.
# NOTE(review): hard-coded, machine-specific path — point this at your local
# checkout of eb-linreg-dsc before running.
srcdir = "/home/saikat/Documents/work/ebmr/simulation/eb-linreg-dsc/dsc"
sys.path.append(os.path.join(srcdir, "functions"))
from fit import fit_em_vamp
import simulate
# Settings forwarded to simulate.changepoint_predictors.
n, p = 100, 200   # rows (samples) and columns (predictors) of X
s = 1             # presumably the number of changepoints / non-zero effects — confirm
snr = 10          # presumably the signal-to-noise ratio — confirm
signal = "fixed"
bfix = 8

# Training data, held-out data, true coefficients and the noise scale `se`.
X, y, Xtest, ytest, beta, se = simulate.changepoint_predictors(
    n, p, s, snr, signal=signal, bfix=bfix
)

# Overlay the noiseless signal X.beta on the noisy responses.
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.plot(np.arange(n), np.dot(X, beta), label="Xb")
ax1.scatter(np.arange(n), y, edgecolor='black', facecolor='white', label="Xb + e")
ax1.legend()
ax1.set_xlabel("Sample index")
ax1.set_ylabel("y")
plt.show()
def get_rmse(X, y, bhist):
    """Return the RMSE of the prediction made by each coefficient vector in ``bhist``.

    Each entry of ``bhist`` is read as ``[intercept, b_1, ..., b_p]``. The
    first element is added as an intercept, and the remaining ``p`` elements
    multiply the columns of ``X``.

    Parameters
    ----------
    X : array of shape (n, p)
        Design matrix on which predictions are evaluated.
    y : array with n elements
        Observed responses for the rows of ``X`` (1-D or column-shaped).
    bhist : sequence of arrays with p + 1 elements each
        Coefficient history, one entry per iteration (1-D or column-shaped).

    Returns
    -------
    numpy.ndarray of shape (len(bhist),)
        ``rmse[i]`` is the root-mean-square error of entry ``i`` on ``(X, y)``.
    """
    # Flatten once so a column-shaped y cannot broadcast against a 1-D
    # prediction into an (n, n) residual matrix.
    yobs = np.asarray(y).reshape(-1)
    rmse = np.zeros(len(bhist))
    for it, bhati in enumerate(bhist):
        bhati = np.asarray(bhati).reshape(-1)
        ypred = np.dot(X, bhati[1:]) + bhati[0]
        rmse[it] = np.sqrt(np.mean((yobs - ypred) ** 2))
    return rmse

Run the EM-VAMP solver

# Fit EM-VAMP on (X, y).
# NOTE(review): outputs are presumably (coefficient history, intercept,
# final coefficients) — confirm against fit.fit_em_vamp.
(emvamp_bhist,
 emvamp_mu,
 emvamp_beta) = fit_em_vamp(X, y)

Convergence check on input data

emvamp_rmse = get_rmse(X, y, emvamp_bhist)

# Log-log convergence plot: 1-based iteration number vs. RMSE scaled by se.
fig = plt.figure(figsize=(6, 6))
ax1 = fig.add_subplot(111)

xvals = np.log10(np.arange(emvamp_rmse.shape[0]) + 1)
ax1.plot(xvals, np.log10(emvamp_rmse / se), 's-', alpha=0.5)

ax1.set_xlabel(r"Number of iterations")
ax1.set_ylabel(r"Prediction Error (RMSE / $\sigma$)")
ax1.tick_params(labelcolor="#333333")

# PyMir "magic" ticks; remove these two calls if PyMir is not available.
mpl_utils.set_yticks(ax1, kmin=4, kmax=6, scale='log10', spacing='log10')
mpl_utils.set_xticks(ax1, kmin=4, kmax=6, scale='log10', spacing='log10')

plt.show()

Prediction on test data

# Predict the held-out responses with the final EM-VAMP coefficients and
# overlay the prediction on the observed test responses.
# NOTE(review): no intercept is added here, whereas get_rmse adds the first
# element of each bhist entry as an intercept; confirm whether emvamp_mu
# should be added to vamp_ypred.
vamp_ypred = np.dot(Xtest, emvamp_beta.reshape(-1))

fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.scatter(np.arange(vamp_ypred.shape[0]), ytest, edgecolor = 'black', facecolor='white')
ax1.plot(np.arange(vamp_ypred.shape[0]), vamp_ypred)
ax1.set_xlabel("Sample index")
ax1.set_ylabel("y")

plt.show()