Comparison of different methods for changepoint estimation
- About
- Importing packages
- Simulations
- Trend-filtering with single changepoint / knot (s = 1)
- Trend-filtering with four changepoints / knots (s = 4)
- Dependence of prediction errors on the number of knots and trend-filtering order
About
Here, we consider a class of non-parametric regression problems, in particular
$y_i = \mu_i + e_i$ with $i = 1, \cdots, n,$
where the goal is to estimate the underlying mean $\mu_i$ under the assumption that it varies in a spatially structured way. For a comprehensive discussion of this problem, see Tibshirani, 2014. One very simple way to capture spatial structure in $\mu$ is to model it as a (sparse) linear combination of step functions:
$\mathbf{\mu} = \mathbf{X}\mathbf{b}$
where the $j$th column of $\mathbf{X}$ is the step function with a step at $j$; that is, $x_{ij} = 0$ for $i \lt j$ and $1$ for $i \ge j$. The $j$th element of $\mathbf{b}$ therefore determines the change in the mean $\mu_j - \mu_{j - 1}$ (since $\mu_i = \sum_{j \le i} b_j$), and an assumption that $\mathbf{b}$ is sparse encapsulates an assumption that $\mu$ is spatially structured (indeed, piecewise constant). This very simple approach is essentially 0th-order trend filtering. Higher-order trend filtering can be similarly implemented using a different basis matrix $\mathbf{X}$. For any order $k$, the elements of the basis matrix can be defined as,
$ x_{ij} = \begin{cases} i^{j - 1} / n^{j - 1}, & \text{for $i = 1, \cdots, n$, $j = 1, \cdots, k+1$} \\ 0, & \text{for $i \le j - l$, $j \ge k + 2$,} \\ (i - j + l)^{k} / n^k & \text{for $i \gt j - l$, $j \ge k + 2$} \end{cases} $
where $ l = \begin{cases} k / 2 & \text{if $k \gt 0$ is even} \\ (k + 1) / 2 & \text{if $k \gt 0$ is odd} \end{cases} $
In trend-filtering language, the non-zero elements of $\mathbf{b}$ are called "knots"; with $s$ knots, the degrees of freedom ($d$) is defined as,
$ d = s + k + 1$.
Importing packages
Non-standard packages include DSC and PyMir. The simulation repository needs to be on the Python path (`sys.path`) so that some of its utilities can be imported.
import pandas as pd
import numpy as np
#import math
import os
import sys
import collections
srcdir = "/home/saikat/Documents/work/ebmr/simulation/eb-linreg-dsc"
sys.path.append(os.path.join(srcdir, "analysis"))
import dscrutils2py as dscrutils
import methodprops
import methodplots
import dsc_extract
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
from matplotlib.legend_handler import HandlerPatch
from pymir import mpl_stylesheet
from pymir import mpl_utils
from pymir import pd_utils
mpl_stylesheet.banskt_presentation()
Simulations
The simulation pipeline is implemented using Dynamic Statistical Comparisons (DSC).
In the current benchmark, we look at simple examples of trend-filtering for constant, linear and quadratic orders ($k = 0, 1$ and $2$). Although the trend filtering estimates are only defined at the discrete inputs, we use linear interpolation to extend the estimates for visualization purposes.
Finally, we also look at the mean squared error on test data over a range of "knots", $s = 1, 2, 4, 6, 8, 10, 15$ and $20$.
For fitting the regression methods, we assume that the basis is known, which might not be the case for real data.
# Location of the DSC results for the changepoint benchmark.
dsc_outdir = os.path.join(srcdir, "dsc/dsc_result_changepoint")
dims = (500, 500)

# DSC targets (module names and module outputs) requested from the query.
targets = [
    "changepoint",
    "changepoint.dims",
    "changepoint.se",
    "changepoint.sfix",
    "changepoint.basis_k",
    "changepoint.snr",
    "fit_cpt",
    "fit_cpt.DSC_TIME",
    "mse.err",
]

# Regression methods compared in the plots.
methods = [
    "lasso",
    "elastic_net",
    "susie",
    "mr_ash",
    "ebmr_lasso",
    "ebmr_ash",
    "ebmr_ashR",
    "em_iridge",
]
conditions = None

# Trend-filtering orders and numbers of knots used in the plots.
orders = [0, 1, 2]
knots = [1, 2, 4, 6, 8, 10, 15, 20]

groups = ["fit:"]
dscout = dscrutils.dscquery(dsc_outdir, targets, groups=groups)

# score1 = sqrt(mse.err) / changepoint.se, i.e. RMSE scaled by the value
# stored in the `changepoint.se` column.
dscout['score1'] = np.sqrt(dscout['mse.err']) / dscout['changepoint.se']
class HandlerSquare(HandlerPatch):
    """Legend handler that renders a patch handle as a square swatch."""

    def create_artists(self, legend, orig_handle,
                       xdescent, ydescent, width, height, fontsize, trans):
        """Return a single square `Rectangle` standing in for `orig_handle`.

        The square's side equals `height`; `width` and `fontsize` are not
        used. Its lower-left corner is placed at `(xdescent, ydescent - 2.0)`.
        Visual properties are copied from `orig_handle` through
        `update_prop`, and `trans` is installed as the artist transform.
        """
        anchor = (xdescent, ydescent - 2.0)
        # Both dimensions use `height` so the swatch is square.
        square = mpatches.Rectangle(xy=anchor, width=height, height=height, angle=0.0)
        self.update_prop(square, orig_handle, legend)
        square.set_transform(trans)
        return [square]
def get_mse_err_dsc(dscout, colname, method, orders, s = 1):
    """Collect log10-transformed values of one column, one array per order.

    Parameters
    ----------
    dscout : table of DSC results, passed to `pd_utils.select_dfrows`.
    colname : name of the column whose values are returned.
    method : value matched against `fit_cpt`.
    orders : iterable of values matched against `changepoint.basis_k`.
    s : value matched against `changepoint.sfix` (default 1).

    Returns
    -------
    list
        One entry per element of `orders`, in the same order; each entry is
        `np.log10` of the selected `colname` values as a NumPy array.
    """
    # Narrow down to the requested method and sfix once, then split by order.
    method_rows = pd_utils.select_dfrows(
        dscout,
        [f"$(fit_cpt) == {method}", f"$(changepoint.sfix) == {s}"],
    )

    def _log_values(order):
        order_rows = pd_utils.select_dfrows(method_rows, [f"$(changepoint.basis_k) == {order}"])
        return np.log10(order_rows[colname].to_numpy())

    return [_log_values(order) for order in orders]
def trend_filter_example_plot(dsc_outdir, dscout, methods, sfix, orders, plot_iter, colname = 'score1'):
    """Show example fits and a per-method error box plot for one knot setting.

    The figure has two rows on a 2 x 3 grid: the top row holds one example
    panel per entry of `orders` (so at most three orders fit), the bottom row
    holds a box plot of `colname` grouped by method and order (first two
    columns) and the method legend (last column). The figure is displayed
    with `plt.show()`; nothing is returned.

    Parameters
    ----------
    dsc_outdir : DSC output directory, forwarded to
        `dsc_extract.changepoint_predictions`.
    dscout : table of DSC results, forwarded to `get_mse_err_dsc`.
    methods : iterable of method names; each must be a key of
        `methodprops.plot_metainfo()` and of the returned predictions.
    sfix : value of `changepoint.sfix` to show.
    orders : sequence of trend-filtering orders; one example panel each.
    plot_iter : mapping from order to the `dsc_iter` used for its example.
    colname : column of `dscout` summarised in the box plot
        (log10-transformed by `get_mse_err_dsc`).
    """
    figw = 12
    figh = 8
    nrow = 2
    ncol = 3
    wspace = 0.4
    hspace = 0.5

    # Look up the per-method plot properties once instead of on every
    # iteration (assumes plot_metainfo() is a pure lookup).
    metainfo = methodprops.plot_metainfo()

    # Main plot structure
    mpl_stylesheet.banskt_presentation(splinecolor = 'black', fontsize = 12, splinewidth = 1, dpi = 300)
    fig = plt.figure(figsize = (figw, figh))
    gs = gridspec.GridSpec(nrow, ncol)
    gs.update(wspace = wspace, hspace = hspace)

    # Legend entries: one coloured patch per method, in the order of `methods`.
    mhandles = [mpatches.Rectangle((0, 0), 20, 20, color = metainfo[method].color)
                for method in methods]
    mlabels = [metainfo[method].label for method in methods]

    # Examples
    axlist = list()
    for i, order in enumerate(orders):
        ax = fig.add_subplot(gs[0, i])
        _, y, _, ytest, _, _, ypred, _, _ \
            = dsc_extract.changepoint_predictions(dsc_outdir, methods,
                  order = order, sfix = sfix, dsc_iter = plot_iter[order])
        # x-coordinates are sized from the array actually plotted, so the
        # scatter does not depend on `ytest` and `y` having equal length.
        ax.scatter(np.arange(ytest.shape[0]), ytest, color = '#A7A7A7', facecolor = 'None', s = 10)
        for method in methods:
            pm = metainfo[method]
            ax.plot(np.arange(y.shape[0]), ypred[method], color = pm.color, label = pm.label)
        axlist.append(ax)

    # Square errors boxplot
    axc = fig.add_subplot(gs[1, :ncol - 1])
    norder = len(orders)
    # Each method occupies `norder` adjacent slots followed by one empty slot.
    xticks = [j * (norder + 1) + i for j in range(len(methods)) for i in range(norder)]
    # Label every box with its actual order value, not its position in `orders`.
    xticklabels = [order for _ in methods for order in orders]
    for j, method in enumerate(methods):
        pm = metainfo[method]
        mse = get_mse_err_dsc(dscout, colname, method, orders, s = sfix)
        xpos = xticks[j * norder:(j + 1) * norder]
        boxprops = dict(linewidth = 2, color = pm.color, facecolor = pm.color, alpha = 0.6)
        medianprops = dict(linewidth = 2, color = pm.color)
        whiskerprops = dict(color = pm.color)
        flierprops = dict(marker = 'o', markerfacecolor = pm.color, markersize=4,
                          markeredgewidth = 0, markeredgecolor = pm.color)
        axc.boxplot(mse, positions = xpos, widths = 0.6, showfliers = True, showcaps = True,
                    patch_artist = True, boxprops = boxprops,
                    medianprops = medianprops, whiskerprops = whiskerprops, flierprops = flierprops)
    axc.set_xticks(xticks, minor = False)
    axc.set_xticklabels(xticklabels)
    mpl_utils.set_yticks(axc, scale = 'log10', kmin = 3, kmax = 4, spacing = 'linear')

    # Legend
    axl = fig.add_subplot(gs[1, ncol - 1])
    axl.tick_params(bottom = False, left = False, labelbottom = False, labelleft = False)
    legendtitle = 'Methods'
    mhandler_map = {handle: HandlerSquare() for handle in mhandles}
    axl.legend(handles = mhandles, labels = mlabels, handler_map = mhandler_map, ncol = 2,
               loc = 'upper left', bbox_to_anchor = (0, 1.0), frameon = False, title = legendtitle)
    mpl_utils.decorate_axes(axl, hide = ["all"], ticklimits = False)

    # Axes labels
    axlist[0].set_ylabel(r"$y$")
    # The shared x-label goes on the middle example panel (index 1 for three
    # orders); this also works when fewer than two orders are given.
    axlist[len(axlist) // 2].set_xlabel(r"Sample index", labelpad = 10)
    axc.set_ylabel(r"RMSE / $\sigma$")
    axc.set_xlabel("Trend-filtering order")
    plt.show()
# Single changepoint / knot (s = 1). `plot_iter` maps each trend-filtering
# order to the `dsc_iter` whose fit is shown in the example panel.
plot_iter = {0: 2, 1: 1, 2: 10}
trend_filter_example_plot(
    dsc_outdir, dscout, methods,
    sfix = 1,
    orders = orders,
    plot_iter = plot_iter,
    colname = 'score1',
)

# Four changepoints / knots (s = 4).
plot_iter = {0: 1, 1: 11, 2: 16}
trend_filter_example_plot(
    dsc_outdir, dscout, methods,
    sfix = 4,
    orders = orders,
    plot_iter = plot_iter,
    colname = 'score1',
)
def single_plot_score_methods(ax, resdf, colname, methods, knots, order, xscale, yscale, use_median = False):
    """Plot, for every method, a summary of `colname` against the number of knots.

    For each method and each entry of `knots`, the rows of `resdf` matching
    the method (`fit_cpt`), the order (`changepoint.basis_k`) and the knot
    count (`changepoint.sfix`) are selected, NaN values are discarded, and
    the remaining values are reduced to their mean or median. One line per
    method is drawn on `ax`, styled from `methodprops.plot_metainfo()`.

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    resdf : table of DSC results, passed to `pd_utils.select_dfrows`.
    colname : column holding the score to summarise.
    methods : iterable of method names (keys of `methodprops.plot_metainfo()`).
    knots : sequence of `changepoint.sfix` values; used as x-coordinates.
    order : value matched against `changepoint.basis_k`.
    xscale, yscale : scale names forwarded to `mpl_utils.scale_array`.
    use_median : summarise with the median instead of the mean.

    Returns
    -------
    None
    """
    summarize = np.median if use_median else np.mean
    # Look up the plot properties once (assumes plot_metainfo() is a pure lookup).
    metainfo = methodprops.plot_metainfo()

    for method in methods:
        mconditions = [f"$(fit_cpt) == {method}",
                       f"$(changepoint.basis_k) == {order}"]
        score = list()
        for sfix in knots:
            sfix_condition = [f"$(changepoint.sfix) == {sfix}"]
            dfselect = pd_utils.select_dfrows(resdf, mconditions + sfix_condition)
            scores = dfselect[colname].to_numpy()
            valid = scores[~np.isnan(scores)]
            # With no usable values the summary is NaN; assigning it directly
            # avoids NumPy's empty-slice RuntimeWarning.
            score.append(summarize(valid) if valid.size > 0 else np.nan)

        # Plot knots vs score
        pm = metainfo[method]
        xx = mpl_utils.scale_array(knots, xscale)
        yy = mpl_utils.scale_array(score, yscale)
        ax.plot(xx, yy, label = pm.label,
                color = pm.color, lw = pm.linewidth / 2, ls = pm.linestyle,
                marker = pm.marker, ms = pm.size / 1.2, mec = pm.color, mfc = pm.facecolor,
                mew = pm.linewidth, zorder = pm.zorder
               )
    return
# Prediction error versus number of knots: one panel per trend-filtering
# order on a 2 x 2 grid, with the method legend in the first unused cell.
figw = 12
figh = 12
nrow = 2
ncol = 2
wspace = 0.5
hspace = 0.5
xscale = 'log10'
yscale = 'log10'
mtitles = {0: "Constant basis (k = 0)",
           1: "Linear basis (k = 1)",
           2: "Quadratic basis (k = 2)"}

# Main plot structure
mpl_stylesheet.banskt_presentation(splinecolor = 'black', fontsize = 12, splinewidth = 1, dpi = 300)
fig = plt.figure(figsize = (figw, figh))
gs = gridspec.GridSpec(nrow, ncol)
gs.update(wspace = wspace, hspace = hspace)

axlist = list()
for i, order in enumerate(orders):
    # Fill the grid row by row.
    irow, icol = divmod(i, ncol)
    ax = fig.add_subplot(gs[irow, icol])
    single_plot_score_methods(ax, dscout, 'score1', methods, knots, order,
                              xscale, yscale, use_median = True)
    mpl_utils.set_soft_ylim(ax, 1.0, 1.2, scale = yscale)
    ax.set_title(mtitles[order], pad = 10)
    mpl_utils.set_xticks(ax, scale = xscale, tickmarks = knots)
    mpl_utils.set_yticks(ax, scale = yscale, kmin = 3, kmax = 4, forceticks = [1.0])
    mpl_utils.decorate_axes(ax, hide = ["top", "right"], ticklimits = True)
    ax.set_xlabel(r"Number of non-zero coefficients (s)")
    ax.set_ylabel(r"Prediction Error (RMSE / $\sigma$)")
    axlist.append(ax)

# Legend: placed in the grid cell right after the last panel. The position is
# derived from len(orders) rather than from the leaked loop index.
irow, icol = divmod(len(orders), ncol)
axl = fig.add_subplot(gs[irow, icol])
axl.tick_params(bottom = False, left = False, labelbottom = False, labelleft = False)
legendtitle = 'Methods'
# Reuse the line handles of the first panel so the legend shows the same
# line and marker styles as the plots.
mhandles, mlabels = axlist[0].get_legend_handles_labels()
legend1 = axl.legend(handles = mhandles, labels = mlabels,
                     handlelength = 3,
                     loc = 'upper left', bbox_to_anchor = (0, 1.0), frameon = False, title = legendtitle)
mpl_utils.decorate_axes(axl, hide = ["all"], ticklimits = False)
plt.show()