About
There are two alternate proposed methods for updating the penalty \mu on the residual. The penalty is equivalent to the step size. Here, we compare the two methods on a simulated dataset. The proposed methods are:
To-Do: Check stopping criterion of Boyd et. al. using Eq. 3.12.
Getting setup
Code
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pymir import mpl_stylesheet
from pymir import mpl_utils
mpl_stylesheet.banskt_presentation(splinecolor = 'black' , dpi = 120 , colors = 'kelly' )
import sys
sys.path.append("../utils/" )
import histogram as mpy_histogram
import simulate as mpy_simulate
import plot_functions as mpy_plotfn
from nnwmf.optimize import IALM
Simulate some toy data
ntrait = 4 # categories / class
ngwas = 500 # N
nsnp = 1000 # P
nfctr = 40 # K
Code
Y, Y_true, L, F, mean, noise_var, sample_indices = mpy_simulate.simulate(ngwas, nsnp, ntrait, nfctr, std = 0.5 , do_shift_mean = False )
Y_cent = mpy_simulate.do_standardize(Y, scale = False )
Y_std = mpy_simulate.do_standardize(Y)
Y_true_cent = mpy_simulate.do_standardize(Y_true, scale = False )
Code
unique_labels = list (range (len (sample_indices)))
class_labels = [None for x in range (ngwas)]
for k, idxs in enumerate (sample_indices):
for i in idxs:
class_labels[i] = k
IALM with default update of step size
Code
rpca = IALM(benchmark = True , max_iter = 1000 , show_progress = True )
rpca.fit(Y_cent, Xtrue = Y_cent)
2023-07-28 13:24:39,445 | nnwmf.optimize.inexact_alm | INFO | Iteration 0. Primal residual 0.892391. Dual residual 0.00374201
2023-07-28 13:24:55,941 | nnwmf.optimize.inexact_alm | INFO | Iteration 100. Primal residual 0.0031319. Dual residual 1.12454e-05
2023-07-28 13:25:11,945 | nnwmf.optimize.inexact_alm | INFO | Iteration 200. Primal residual 2.69409e-06. Dual residual 1.09007e-05
IALM with ADMM update of step size
Code
rpca_admm = IALM(benchmark = True , max_iter = 1000 , mu_update_method= 'admm' , show_progress = True )
rpca_admm.fit(Y_cent, Xtrue = Y_cent)
2023-07-28 13:25:22,352 | nnwmf.optimize.inexact_alm | INFO | Iteration 0. Primal residual 0.892391. Dual residual 0.00374201
2023-07-28 13:25:38,124 | nnwmf.optimize.inexact_alm | INFO | Iteration 100. Primal residual 1.33026e-06. Dual residual 1.61284e-06
Results
Code
fig = plt.figure(figsize = (20 , 6 ))
ax1 = fig.add_subplot(131 )
ax2 = fig.add_subplot(132 )
ax3 = fig.add_subplot(133 )
ax1.plot(range (len (rpca.mu_list_)), rpca.mu_list_, 'o-' , label = 'Default' )
ax1.plot(range (len (rpca_admm.mu_list_)), rpca_admm.mu_list_, 'o-' , label = 'ADMM' )
ax1.legend()
ax1.set_xlabel("Iteration" )
ax1.set_ylabel("Step size" )
ax2.plot(range (len (rpca.primal_res_)), np.log10(rpca.primal_res_), 'o-' , label = 'Primal Residual - Default' )
ax2.plot(range (len (rpca.dual_res_) - 1 ), np.log10(rpca.dual_res_[1 :]), 'o-' , label = 'Dual Residual - Default' )
ax2.plot(range (len (rpca_admm.primal_res_)), np.log10(rpca_admm.primal_res_), 'o-' , label = 'Primal Residual - ADMM' )
ax2.plot(range (len (rpca_admm.dual_res_) - 1 ), np.log10(rpca_admm.dual_res_[1 :]), 'o-' , label = 'Dual Residual - ADMM' )
ax2.legend()
ax2.set_xlabel("Iteration" )
ax2.set_ylabel("Primal and Dual Residuals (log10 scale)" )
ax3.plot(np.log10(np.cumsum(rpca.cpu_time_[1 :])), rpca.rmse_[1 :], 'o-' , label = 'Default' )
ax3.plot(np.log10(np.cumsum(rpca_admm.cpu_time_[1 :])), rpca_admm.rmse_[1 :], 'o-' , label = 'ADMM' )
ax3.legend()
ax3.set_xlabel("CPU Time" )
ax3.set_ylabel("RMSE from input matrix" )
plt.tight_layout(w_pad = 2.0 )
plt.show()
Principal Components for the low rank matrix
Default
Code
Y_rpca_cent = mpy_simulate.do_standardize(rpca.L_, scale = False )
U_rpca, S_rpca, Vt_rpca = np.linalg.svd(Y_rpca_cent, full_matrices = False )
pcomps_rpca = U_rpca @ np.diag(S_rpca)
axmain, axs = mpy_plotfn.plot_principal_components(pcomps_rpca, class_labels, unique_labels)
plt.show()
ADMM
Code
Y_rpca_cent = mpy_simulate.do_standardize(rpca_admm.L_, scale = False )
U_rpca, S_rpca, Vt_rpca = np.linalg.svd(Y_rpca_cent, full_matrices = False )
pcomps_rpca = U_rpca @ np.diag(S_rpca)
axmain, axs = mpy_plotfn.plot_principal_components(pcomps_rpca, class_labels, unique_labels)
plt.show()