Source code for knockpy.kpytorch.mrcgrad

""" Gradient-based methods for solving MRC problems.
Currently only used for group-knockoffs."""

import warnings
import time
import numpy as np
import scipy as sp

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..constants import GAMMA_VALS
from .. import utilities, mrc


def block_diag_sparse(*arrs):
    """
    Given a list of 2D torch tensors, creates a sparse block-diagonal matrix.
    See https://github.com/pytorch/pytorch/issues/31942
    """
    bad_args = []
    for k, arr in enumerate(arrs):
        # if isinstance(arr, nn.Parameter):
        #     arr = arr.data
        if not (isinstance(arr, torch.Tensor) and arr.ndim == 2):
            bad_args.append(k)
    if len(bad_args) != 0:
        raise ValueError(f"Args at positions {bad_args} must be 2D tensors")

    shapes = torch.tensor([a.shape for a in arrs])
    out = torch.zeros(
        torch.sum(shapes, dim=0).tolist(),
        dtype=arrs[0].dtype,
        device=arrs[0].device,
    )
    r, c = 0, 0
    for i, (rr, cc) in enumerate(shapes):
        out[r : r + rr, c : c + cc] = arrs[i]
        r += rr
        c += cc
    return out
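# Illustrative sketch (editor's addition, not part of the original module):
# shows how block_diag_sparse assembles two small 2D tensors into one
# block-diagonal matrix. The helper name _example_block_diag is hypothetical.
def _example_block_diag():
    a = torch.eye(2)
    b = torch.full((1, 1), 3.0)
    out = block_diag_sparse(a, b)
    # out is a 3 x 3 tensor with `a` in the top-left corner and `b` in the
    # bottom-right corner; all off-block entries are zero.
    assert out.shape == (3, 3)
    return out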
class MVRLoss(nn.Module):
    """
    A pytorch class to compute S-matrices for (gaussian) MX knockoffs
    which minimizes the trace of the feature-knockoff precision matrix
    (the inverse of the feature-knockoff covariance/Gram matrix, G).

    :param Sigma: p x p numpy matrix. Must already be sorted by groups.
    :param groups: p length numpy array of groups. These must already be
        sorted and correspond to the Sigma.
    :param init_S: The initialization values for the S block-diagonal matrix.
        - A p x p matrix. The block-diagonal of this matrix, as specified by
          groups, will be the initial values for the S matrix.
        - A list of square numpy matrices, with the ith element corresponding
          to the block of the ith group in S.
        Defaults to a scaled identity matrix (equicorrelated initialization).
    :param invSigma: Optional precomputed p x p inverse of Sigma. If None,
        it is computed from Sigma.
    :param rec_prop: The proportion of knockoffs that will be recycled.
        (The optimal S matrix depends on the recycling proportion.)
    :param smoothing: Calculate the loss as sum 1/(eigs + smoothing)
        as opposed to sum 1/eigs. This is helpful if fitting lasso statistics
        on extremely degenerate covariance matrices. Over the course of
        optimization, this smoothing parameter will go to 0.
    :param min_smoothing: The smallest value the smoothing parameter may be
        reduced to during optimization.
    :param method: One of "mvr" or "maxent" ("mmi" for backwards compatibility).
    """

    def __init__(
        self,
        Sigma,
        groups,
        init_S=None,
        invSigma=None,
        rec_prop=0,
        smoothing=0.01,
        min_smoothing=1e-4,
        method="mvr",
    ):
        super().__init__()

        # Groups MUST be sorted
        sorted_groups = np.sort(groups)
        if not np.all(groups == sorted_groups):
            raise ValueError("Sigma and groups must be sorted prior to input")

        # Save Sigma and groups
        self.p = Sigma.shape[0]
        self.groups = torch.from_numpy(groups).long()
        self.group_sizes = torch.from_numpy(utilities.calc_group_sizes(groups)).long()
        self.Sigma = torch.from_numpy(Sigma).float()

        # Save inverse cov matrix
        if invSigma is None:
            invSigma = utilities.chol2inv(Sigma)
        self.invSigma = torch.from_numpy(invSigma).float()

        # Save recycling proportion and smoothing
        self.smoothing = smoothing
        self.min_smoothing = min_smoothing
        self.rec_prop = rec_prop
        self.method = method

        # Make sure init_S is a numpy array
        if init_S is None:
            # If nothing provided, default to equicorrelated
            scale = min(1, 2 * utilities.calc_mineig(Sigma))
            init_S = scale * np.eye(self.p)
        elif isinstance(init_S, list):
            # Check for correct number of blocks
            num_blocks = len(init_S)
            num_groups = np.unique(groups).shape[0]
            if num_blocks != num_groups:
                raise ValueError(
                    f"Length of init_S {num_blocks} doesn't agree with num groups {num_groups}"
                )
            init_S = sp.linalg.block_diag(*init_S)

        # Find a good initial scaling
        best_gamma = 1
        best_loss = np.inf
        if method == "mvr":
            objective = mrc.mvr_loss
        else:
            objective = mrc.maxent_loss
        for gamma in GAMMA_VALS:
            loss = objective(Sigma=Sigma, S=(1 - self.rec_prop) * gamma * init_S)
            if loss >= 0 and loss < best_loss:
                best_gamma = gamma
                best_loss = loss
        init_S = best_gamma * init_S

        # Create new blocks
        blocks = utilities.blockdiag_to_blocks(init_S, groups)
        # Torch-ify and take the Cholesky square root
        blocks = [torch.from_numpy(block) for block in blocks]
        blocks = [torch.linalg.cholesky(block) for block in blocks]

        # Save and register the blocks as parameters
        self.blocks = [nn.Parameter(block.float()) for block in blocks]
        for i, block in enumerate(self.blocks):
            self.register_parameter(f"block{i}", block)
        self.update_sqrt_S()
        self.scale_sqrt_S(tol=1e-5, num_iter=10)
    def update_sqrt_S(self):
        """ Updates sqrt_S using the block parameters """
        self.sqrt_S = block_diag_sparse(*self.blocks)
    def pull_S(self):
        """ Returns the S matrix """
        self.update_sqrt_S()
        S = torch.mm(self.sqrt_S.t(), self.sqrt_S)
        return S
    def forward(self, smoothing=None):
        """ Calculates the trace of the inverse Gram feature-knockoff matrix """
        # TODO: This certainly works and is more efficient in a forward
        # pass than taking the eigenvalues of both S and 2*Sigma - S.
        # But perhaps the dot product makes the backprop less efficient?

        # Infer smoothing
        if smoothing is None:
            smoothing = self.smoothing

        # Create the Schur complement of Sigma in the feature-knockoff
        # covariance G = [[Sigma, Sigma - S], [Sigma - S, Sigma]]. Both
        # diagonal blocks of G^{-1} equal the inverse of this Schur
        # complement, so trace(G^{-1}) = 2 * trace(G_schurr^{-1}); for
        # method="mvr", the sum below is the MVR loss up to a factor of 2.
        S = self.pull_S()
        S = (1 - self.rec_prop) * S  # Account for recycling when computing the loss
        diff = self.Sigma - S
        G_schurr = self.Sigma - torch.mm(torch.mm(diff, self.invSigma), diff)

        # Take eigenvalues
        eigvals = torch.linalg.eigvalsh(G_schurr, UPLO="U")
        if self.method == "mvr":
            inv_eigs = 1 / (smoothing + eigvals)
        elif self.method == "maxent":
            inv_eigs = torch.log(
                1 / torch.max(smoothing + eigvals, torch.tensor(smoothing).float())
            )
        return inv_eigs.sum()
    def scale_sqrt_S(self, tol, num_iter):
        """ Scales sqrt_S such that 2 Sigma - S is PSD. """
        # No gradients
        with torch.no_grad():
            # This shift only applies to 1x1 blocks
            for block in self.blocks:
                if block.data.shape[0] == 1:
                    block.data = torch.max(torch.tensor(tol).float(), block.data)

            # Construct S
            S = self.pull_S()
            # Find optimal scaling
            _, gamma = utilities.scale_until_PSD(
                self.Sigma.numpy(), S.numpy(), tol=tol, num_iter=num_iter
            )
            # Scale blocks
            for block in self.blocks:
                block.data = np.sqrt(gamma) * block.data
            self.update_sqrt_S()
    def project(self, **kwargs):
        """ Project by scaling sqrt_S """
        self.scale_sqrt_S(**kwargs)
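# Illustrative sketch (editor's addition, not part of the original module):
# constructing an MVRLoss on a small equicorrelated covariance matrix with
# three sorted groups of size two, then evaluating the smoothed objective,
# i.e. the sum over the Schur complement's eigenvalues of 1/(eig + smoothing).
# The helper name _example_mvr_loss is hypothetical.
def _example_mvr_loss():
    p, rho = 6, 0.5
    Sigma = (1 - rho) * np.eye(p) + rho * np.ones((p, p))
    groups = np.array([1, 1, 2, 2, 3, 3])  # already sorted, as required
    losscalc = MVRLoss(Sigma=Sigma, groups=groups, method="mvr")
    with torch.no_grad():
        loss = losscalc()
    return losscalc.pull_S(), loss.item()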
class PSGDSolver:
    """
    Projected gradient descent to solve for MRC knockoffs.
    This will work for non-convex loss objectives as well,
    although it's a heuristic optimization method.

    :param Sigma: p x p numpy array, the correlation matrix
    :param groups: p-length numpy array specifying groups
    :param losscalc: A pytorch class wrapping nn.Module which contains
        the following methods:
        - .forward() which calculates the loss based on the internally
          stored S matrix.
        - .project() which ensures that both the internally-stored S matrix
          as well as (2*Sigma - S) are PSD.
        - .pull_S(), which returns the internally-stored S matrix.
        If None, an MVRLoss instance is created.
    :param lr: Initial learning rate (default 1e-2)
    :param verbose: If True, reports progress
    :param max_epochs: Maximum number of epochs in SGD
    :param tol: Minimum eigenvalue allowed for PSD matrices
    :param line_search_iter: Number of line searches to do when scaling sqrt_S.
    :param convergence_tol: Every ten epochs, we calculate
        improvement = 2/3 * (prev_opt_loss - opt_loss) + 1/3 * improvement.
        When improvement < convergence_tol, we return.
    :param kwargs: Passed to MVRLoss
    """

    def __init__(
        self,
        Sigma,
        groups,
        losscalc=None,
        lr=1e-2,
        verbose=False,
        max_epochs=100,
        tol=1e-5,
        line_search_iter=10,
        convergence_tol=1e-1,
        **kwargs,
    ):
        # Add Sigma
        self.p = Sigma.shape[0]
        self.Sigma = Sigma
        self.groups = groups

        # Output initialization
        self.opt_S = None
        self.opt_loss = np.inf

        # Save parameters for optimization
        self.lr = lr
        self.verbose = verbose
        self.max_epochs = max_epochs
        self.tol = tol
        self.line_search_iter = line_search_iter
        self.convergence_tol = convergence_tol

        # Sort by groups for ease of computation
        inds, inv_inds = utilities.permute_matrix_by_groups(groups)
        self.inds = inds
        self.inv_inds = inv_inds
        self.sorted_Sigma = self.Sigma[inds][:, inds]
        self.sorted_groups = self.groups[inds]

        # Loss calculator
        if losscalc is not None:
            self.losscalc = losscalc
        else:
            self.losscalc = MVRLoss(
                Sigma=self.sorted_Sigma, groups=self.sorted_groups, **kwargs
            )

        # Initialize cache of optimal S
        with torch.no_grad():
            init_loss = self.losscalc(smoothing=0)
            if init_loss < 0:
                init_loss = np.inf
            self.cache_S(init_loss)

        # Initialize attributes which save losses over time
        self.all_losses = []
        self.projected_losses = []
        self.improvements = []

    def cache_S(self, new_loss):
        # Cache optimal solution
        with torch.no_grad():
            self.prev_opt_S = self.opt_S
            self.prev_opt_loss = self.opt_loss
            self.opt_loss = new_loss
            self.opt_S = self.losscalc.pull_S().clone().detach().numpy()
    def optimize(self):
        """ Runs projected gradient descent and returns the optimized S matrix.
        See __init__ for the optimization parameters. """
        # Optimizer
        params = list(self.losscalc.parameters())
        optimizer = torch.optim.Adam(params, lr=self.lr)

        improvement = self.convergence_tol + 10
        for j in range(self.max_epochs):
            # Step 1: Calculate loss (trace of feature-knockoff precision)
            loss = self.losscalc()
            if np.isnan(loss.detach().item()):
                warnings.warn(f"Loss of {self.losscalc.method} solver is NaN")
                break
            self.all_losses.append(loss.item())

            # Step 2: Step along the gradient
            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            optimizer.step()

            # Step 3: Reproject to be PSD
            if j % 3 == 0 or j == self.max_epochs - 1:
                self.losscalc.project(tol=self.tol, num_iter=self.line_search_iter)

                # If this is optimal after reprojecting, save it
                with torch.no_grad():
                    new_loss = self.losscalc(smoothing=0)
                if new_loss < self.opt_loss and new_loss >= 0:
                    self.cache_S(new_loss)
                else:
                    self.prev_opt_S = self.opt_S
                    self.prev_opt_loss = self.opt_loss

                # Cache projected loss
                self.projected_losses.append(new_loss.item())

            # Calculate improvement
            if j != 0 and j % 10 == 0:
                diff = self.prev_opt_loss - self.opt_loss
                l1diff = np.abs(self.opt_S - self.prev_opt_S).sum()
                improvement = 2 * diff / 3 + improvement / 3
                if self.verbose:
                    print(
                        f"L1 diff is {l1diff}, loss diff={diff}, improvement is {improvement}, best loss is {self.opt_loss} at iter {j}"
                    )
                self.improvements.append(improvement)

            # Break if improvement is small
            if improvement < self.convergence_tol and j % 10 == 0:
                if self.losscalc.smoothing > self.losscalc.min_smoothing:
                    # Reset improvement and reduce the smoothing parameter
                    improvement = 1 + self.convergence_tol
                    self.losscalc.smoothing = max(
                        self.losscalc.min_smoothing, self.losscalc.smoothing / 10
                    )
                    if self.verbose:
                        print(
                            f"Nearing convergence, reducing smoothing to {self.losscalc.smoothing} \n"
                        )
                else:
                    if self.verbose:
                        print(f"Converged at iteration {j}")
                    break

        # Shift, scale, and return
        sorted_S = self.opt_S
        S = sorted_S[self.inv_inds][:, self.inv_inds]
        S = utilities.shift_until_PSD(S, tol=self.tol)
        S, _ = utilities.scale_until_PSD(
            self.Sigma, S, tol=self.tol, num_iter=self.line_search_iter
        )
        return S
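# Illustrative sketch (editor's addition, not part of the original module):
# running PSGDSolver directly, here with a maxent loss on an AR(1)-style
# correlation matrix, in order to inspect the loss trajectories it records.
# The helper name _example_psgd_solver is hypothetical.
def _example_psgd_solver():
    p, rho = 8, 0.6
    idx = np.arange(p)
    Sigma = rho ** np.abs(np.subtract.outer(idx, idx))  # AR(1) correlation matrix
    groups = np.repeat(np.arange(1, 5), 2)  # four sorted groups of size two
    solver = PSGDSolver(
        Sigma=Sigma, groups=groups, max_epochs=50, method="maxent"
    )
    S = solver.optimize()
    # solver.all_losses records the smoothed loss at every epoch;
    # solver.projected_losses records the unsmoothed loss after each projection.
    return S, solver.all_losses, solver.projected_losses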
def solve_mrc_psgd(
    Sigma,
    groups=None,
    method="mvr",
    **kwargs,
):
    """
    Wraps the PSGDSolver class.

    :param Sigma: Covariance matrix
    :param groups: Groups for group knockoffs
    :param method: MRC loss ("mvr" or "maxent")
    :param kwargs: Passed to PSGDSolver (and through it to MVRLoss).
    :returns: opt_S
    """
    solver = PSGDSolver(Sigma=Sigma, groups=groups, method=method, **kwargs)
    opt_S = solver.optimize()
    return opt_S
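# Illustrative sketch (editor's addition, not part of the original module):
# the one-call wrapper, followed by a check of the minimum eigenvalues of S
# and 2*Sigma - S, both of which should be (numerically) nonnegative after
# the final shift/scale step. The helper name _example_solve_mrc_psgd is
# hypothetical.
def _example_solve_mrc_psgd():
    p, rho = 6, 0.5
    Sigma = (1 - rho) * np.eye(p) + rho * np.ones((p, p))
    groups = np.array([1, 1, 2, 2, 3, 3])
    opt_S = solve_mrc_psgd(Sigma=Sigma, groups=groups, method="mvr", max_epochs=50)
    print("min eig of S:", np.linalg.eigvalsh(opt_S).min())
    print("min eig of 2*Sigma - S:", np.linalg.eigvalsh(2 * Sigma - opt_S).min())
    return opt_S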