Comparison of prediction accuracy of EM-VAMP, Mr.ASH and EBMR
- About
- Importing packages and DSC results
- High dimension setting (p > n)
- Low dimension setting (p < n)
- Convergence of EM-VAMP and EM-VAMP (ash)
About
Here, I compare the prediction accuracy of EM-VAMP and EM-iRidge with existing penalized regression methods. The simulation scenarios are the same ones used earlier for comparing a few well-known penalized regression methods.
- EM-VAMP. Proposed by Fletcher and Schniter, 2017, this algorithm combines Vector Approximate Message Passing (VAMP) with Expectation Maximization and is well suited for sparse linear regression. I used vampyre for the implementation.
- EM-VAMP (ash). Instead of a spike-and-slab prior, I use the adaptive shrinkage (ash) prior on the coefficients. See, for example, here for further details.
- EM-iRidge. Proposed by Matthew Stephens, 2020, this algorithm uses iterative ridge regression to solve linear regression where the prior on the coefficients is a product of two normal distributions, which has sparsity-inducing properties (see the sketch after this list). I implemented an Expectation Maximization algorithm for iRidge in ebmrPy.
- EBMR (DExp). Proposed by Matthew Stephens, 2020, this algorithm uses Empirical Bayes regression, where the prior variance of each coefficient depends hierarchically on another distribution.
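To make the sparsity-inducing property of the EM-iRidge prior concrete, here is a hedged sketch of one common way such a product-of-normals prior can be written, assuming the description above means each coefficient is the product of two independent Gaussian variables (the symbols $\sigma_u$ and $\sigma_v$ are illustrative and not necessarily the parameterization used in ebmrPy):

$$
b_j = u_j \, v_j, \qquad u_j \sim \mathcal{N}(0, \sigma_u^2), \quad v_j \sim \mathcal{N}(0, \sigma_v^2),
$$

so that the marginal density of $b_j$ has a sharp peak at zero, which is what gives the prior its sparsity-inducing behavior.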
The simulation pipeline is implemented using Dynamic Statistical Comparisons (DSC).
Importing packages and DSC results
Non-standard packages include DSC and PyMir. The simulation repository needs to be on the Python path for importing some of its utilities.
import pandas as pd
import numpy as np
import math
import os
import sys
import collections
srcdir = "/home/saikat/Documents/work/ebmr/simulation/eb-linreg-dsc"
sys.path.append(os.path.join(srcdir, "analysis"))
import dscrutils2py as dscrutils
import methodprops
import methodplots
import convergence_plots as convplots
import dsc_extract
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from pymir import mpl_stylesheet
from pymir import mpl_utils
from pymir import pd_utils
mpl_stylesheet.banskt_presentation()
I have run the simulations using the following settings.
# Dimensions (n, p)
highdims = (100, 200)
lowdims = (500, 200)
# fraction of non-zero predictors, sfrac = p_causal / p
sfracs = [0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0]
# PVE
pve_list = [0.5, 0.95]
# rho
rho_list = [0, 0.95]
# Output directory
dsc_outdir = os.path.join(srcdir, "dsc/dsc_result")
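For reference, here is a quick sketch (separate from the DSC pipeline itself) of how many distinct settings these lists imply; the factor of 20 replicates per setting is taken from the convergence section below.
import itertools

# Every combination of (dims, sfrac, pve, rho) defines one simulation setting:
# 2 dimensions x 7 sparsity fractions x 2 PVE values x 2 rho values = 56 settings.
all_settings = list(itertools.product([highdims, lowdims], sfracs, pve_list, rho_list))
print(len(all_settings))        # 56 distinct settings
print(len(all_settings) * 20)   # 1120 simulated datasets, assuming 20 replicates per setting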
Read the results of the simulations and store them in a DataFrame.
targets = ["simulate", "simulate.dims", "simulate.se", "simulate.rho",
"simulate.sfrac", "simulate.pve", "fit", "fit.DSC_TIME", "mse.err"]
dscout = dscrutils.dscquery(dsc_outdir, targets)
dscout['score1'] = np.sqrt(dscout['mse.err'])/dscout['simulate.se']
dscout
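A note on score1: assuming mse.err is the mean squared prediction error on a held-out test set and simulate.se is the residual standard deviation used in the simulation, score1 is the test RMSE scaled by that residual sd, so values close to 1 indicate near-optimal prediction. A minimal illustration, using hypothetical arrays y_test, y_pred and a hypothetical sigma_true:
def scaled_rmse(y_test, y_pred, sigma_true):
    # RMSE of the predictions divided by the residual sd of the simulation;
    # this mirrors the score1 column computed above.
    return np.sqrt(np.mean((y_test - y_pred) ** 2)) / sigma_true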
Select the methods to be displayed in the figures.
whichmethods = [#"l0learn",
"lasso",
"ridge",
"elastic_net",
#"scad",
#"mcp",
#"blasso",
#"bayesb",
#"susie",
#"varbvs",
#"varbvsmix",
"mr_ash",
"em_vamp",
"em_vamp_ash",
"em_iridge",
"ebmr_lasso",
#"ebmr_ash",
]
High dimension setting (p > n)
highdim_condition = [f"$(simulate.dims) == '({highdims[0]},{highdims[1]})'"]
resdf1 = pd_utils.select_dfrows(dscout, highdim_condition)
methodplots.create_figure_prediction_error(whichmethods, resdf1, highdims,
rho_list, pve_list, sfracs, use_median = True)
Low dimension setting (p < n)
lowdim_condition = [f"$(simulate.dims) == '({lowdims[0]},{lowdims[1]})'"]
resdf2 = pd_utils.select_dfrows(dscout, lowdim_condition)
methodplots.create_figure_prediction_error(whichmethods, resdf2, lowdims,
rho_list, pve_list, sfracs, use_median = True)
Convergence of EM-VAMP and EM-VAMP (ash)
EM-VAMP fails to converge in some simulations. Here, I look at the distribution of prediction errors of EM-VAMP over all 20 simulations at each setting. The failures are most prominent in the high-dimension, high-correlation setting. With the adaptive shrinkage prior, convergence is also compromised in the low-dimension, high-correlation setting.
convplots.create_single_method_score_distribution_plot(dscout, "em_vamp",
[highdims, lowdims], rho_list, pve_list, sfracs, 'score1')
convplots.create_single_method_score_distribution_plot(dscout, "em_vamp_ash",
[highdims, lowdims], rho_list, pve_list, sfracs, 'score1')
The prediction error on a test set over the iterations (cf. Fig. 2 of Fletcher and Schniter, 2017) shows diverging and oscillating behavior.
dim = (100, 200)
pve = 0.95
sfrac = sfracs[2]
method = "em_vamp"
allscores = dict()
for rho in rho_list:
    allscores[rho] = dsc_extract.emvamp_mse_hist(dsc_outdir, method, dim, sfrac, pve, rho)
convplots.create_single_setting_score_evolution_plot(allscores, method, dim, sfrac, pve, rho_list)
dim = (100, 200)
pve = 0.95
sfrac = sfracs[2]
method = "em_vamp_ash"
allscores = dict()
for rho in rho_list:
    allscores[rho] = dsc_extract.emvamp_mse_hist(dsc_outdir, method, dim, sfrac, pve, rho)
convplots.create_single_setting_score_evolution_plot(allscores, method, dim, sfrac, pve, rho_list)
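As a rough follow-up, one could flag the non-converged runs directly from these histories. The sketch below assumes emvamp_mse_hist returns, for each simulation, a sequence of test MSE values over iterations; the threshold of 1.5 is arbitrary.
def looks_nonconverged(mse_hist, tol=1.5):
    # Flag a run whose final MSE is much larger than the best MSE it reached,
    # a simple symptom of the diverging/oscillating behavior seen above.
    return mse_hist[-1] > tol * min(mse_hist)

for rho, histories in allscores.items():
    nbad = sum(looks_nonconverged(h) for h in histories)
    print(f"rho = {rho}: {nbad} of {len(histories)} runs look non-converged")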