Source code for barnacle.fista

import numpy as np


def l1_prox(x, reg):
    """Soft-thresholding: the proximal operator of an l1 sparsity penalty.

    Every entry of `x` is shrunk toward zero by `reg`. Entries whose magnitude
    is at most `reg` become zero, and the remaining entries keep their sign.

    Parameters
    ----------
    x : numpy.ndarray
        Input array.
    reg : float
        L1 sparsity penalty (reg >= 0).

    Returns
    -------
    x : numpy.ndarray
        Input array with l1 proximal operation applied.
    """
    shrunk_magnitude = np.maximum(0, np.abs(x) - reg)
    direction = np.sign(x)
    return direction * shrunk_magnitude
def nn_prox(x, reg):
    """Project onto the non-negative orthant.

    Every negative entry of `x` is replaced by zero.

    Parameters
    ----------
    x : numpy.ndarray
        Input array.
    reg : float
        Unused. Accepted only so that all proximal operators share the
        signature ``prox(x, reg)``.

    Returns
    -------
    x : numpy.ndarray
        Input array with non-negativity constraint applied.
    """
    clipped = np.maximum(0, x)
    return clipped
def l2ball_prox(x, reg):
    """Project each row of `x` onto the unit l2 ball.

    A row whose l2 norm exceeds 1 is rescaled to have norm exactly 1. Rows
    that are already inside the ball are returned unchanged. Norms are taken
    along ``axis=1``, so `x` is expected to be at least 2-D.

    Parameters
    ----------
    x : numpy.ndarray
        Input array.
    reg : float
        Unused. Accepted only so that all proximal operators share the
        signature ``prox(x, reg)``.

    Returns
    -------
    x : numpy.ndarray
        Input array with l2 norm constraint applied.
    """
    row_norms = np.linalg.norm(x, axis=1, keepdims=True)
    # Never scale rows up: the divisor is at least 1.
    divisor = np.maximum(1, row_norms)
    return x / divisor
def nn_l1_prox(x, reg):
    """Proximal operator of an l1 penalty restricted to non-negative values.

    The input is shifted down by `reg` and the result is clipped at zero.

    Parameters
    ----------
    x : numpy.ndarray
        Input array.
    reg : float
        L1 sparsity penalty (reg >= 0).

    Returns
    -------
    x : numpy.ndarray
        Input array with non-negativity constraint and l1 proximal operation
        applied.
    """
    shifted = x - reg
    return np.maximum(0, shifted)
def nn_l2ball_prox(x, reg):
    """Apply a non-negativity constraint followed by a row-wise unit-l2-ball constraint.

    Parameters
    ----------
    x : numpy.ndarray
        Input array.
    reg : float
        Unused. Accepted only so that all proximal operators share the
        signature ``prox(x, reg)``.

    Returns
    -------
    x : numpy.ndarray
        Input array with non-negativity and l2 norm constraints applied.
    """
    # Clip first, then shrink. The shrink divides by a positive factor, so
    # the clipped entries stay non-negative.
    nonnegative_x = nn_prox(x, None)
    return l2ball_prox(nonnegative_x, None)
def _should_continue_backtracking( new_x, y, loss_new_x, loss_y, smooth_grad_y, lipschitz ): """Determines whether or not a backtracking line search to decrease the step size should continue. Based on the 'FISTA with backtracking' algorithm outlined in Beck & Teboulle (2009). If the step length is large, then a large decrease in loss should also be expected. If instead a large step results in a relatively small decrease in the loss, then the step is likely overshooting the target. To prevent this, the algorithm starts with a large guess for the step length, and decreases it until there is a sufficient decrease in the loss. This allows for maximal initial step sizes to be refined as necessary, resulting in a speed up in the optimization. Parameters ---------- new_x : numpy.ndarray Updated solution vector `x`. y : numpy.ndarray Momentum vector `y`. loss_new_x : numpy.ndarray Loss resulting from updated solution vector `x`. loss_y : numpy.ndarray Loss of momentum vector `y`. smooth_grad_y : numpy.ndarray Gradient of momentum vector `y`. lipschitz : float Lipschitz coefficient. Returns ------- continue_backtracking : bool If True, indicates backtracking line search should continue. """ update_vector = new_x - y update_distance = np.sum(update_vector**2) * lipschitz / 2.5 linearised_improvement = smooth_grad_y.ravel().T @ update_vector.ravel() continue_backtracking = loss_new_x - loss_y > update_distance + linearised_improvement return continue_backtracking
def create_loss(AtA, At_b):
    """Build the smooth least-squares loss as a closure over `AtA` and `At_b`.

    The returned function evaluates

        0.5 * (sum(AtA * (x @ x.T)) - 2 * sum(At_b * x))

    which is 0.5 * ||Ax - b||^2 without the constant ||b||^2 term.

    Parameters
    ----------
    AtA : numpy.ndarray
        The matrix A^T A.
    At_b : numpy.ndarray
        The vector A^T b.

    Returns
    -------
    loss : function
        Loss function of a single argument `x`.
    """
    def loss(x):
        quadratic = np.sum(AtA * (x @ x.T))
        linear = np.sum(At_b * x)
        # ||b||^2 is constant in x, so it is left out.
        return 0.5 * (quadratic - 2 * linear)

    return loss
def create_gradient(AtA, At_b):
    """Build the gradient of the smooth least-squares loss as a closure.

    The returned function evaluates ``AtA @ x - At_b``.

    Parameters
    ----------
    AtA : numpy.ndarray
        The matrix A^T A.
    At_b : numpy.ndarray
        The vector A^T b.

    Returns
    -------
    grad : function
        Gradient function of a single argument `x`.
    """
    def grad(x):
        model_term = AtA @ x
        return model_term - At_b

    return grad
def fista_step(
    x, y, t, lipschitz, smooth_grad_y, l1_reg, prox
):
    """Take one FISTA step (Beck & Teboulle, 2009).

    1. Proximal gradient step from the momentum point `y` with step size
       ``1 / lipschitz``.
    2. Momentum-coefficient update
       ``t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2``.
    3. Extrapolation from the new iterate along the direction of travel:
       ``y_{k+1} = x_k + ((t_k - 1) / t_{k+1}) * (x_k - x_{k-1})``.

    Parameters
    ----------
    x : numpy.ndarray
        Current solution `x` (the previous iterate, x_{k-1}).
    y : numpy.ndarray
        Current momentum point `y`.
    t : float
        Momentum coefficient.
    lipschitz : float
        Lipschitz estimate; the gradient step size is ``1 / lipschitz``.
    smooth_grad_y : numpy.ndarray
        Gradient of the smooth loss at `y`.
    l1_reg : float
        L1 regularization coefficient. It is passed to `prox` scaled by the
        step size.
    prox : function
        Proximal operator with signature ``prox(x, reg)``.

    Returns
    -------
    new_x : numpy.ndarray
        Updated solution `x`.
    new_y : numpy.ndarray
        Updated momentum point `y`.
    new_t : float
        Updated momentum coefficient.
    """
    intermediate_step = (1 / lipschitz) * smooth_grad_y
    new_x = prox(y - intermediate_step, l1_reg / lipschitz)

    new_t = 0.5 * (1 + np.sqrt(1 + 4 * t**2))
    momentum = (t - 1) / new_t

    # Extrapolate from the NEW iterate. Starting from the old `x` would give
    # a point between x_{k-1} and x_k: a damped step with no acceleration.
    # With t == 1 it would also discard the step that was just taken.
    dx = new_x - x
    new_y = new_x + momentum * dx

    return new_x, new_y, new_t
def minimise_fista(
    lhs,
    rhs,
    init,
    l1_reg,
    prox,
    n_iter=10,
    tol=1e-6,
    return_err=False,
    line_search=True,
):
    """Use the FISTA algorithm to solve the given optimisation problem:

        min_x 0.5 * ||Ax - b||^2 + reg * g(x)

    Here `prox` is the proximal operator of ``reg * g``. Optimization is
    achieved using the Fast Iterative Shrinkage Thresholding Algorithm (FISTA)
    with backtracking, as described in Beck & Teboulle (2009)
    :cite:p:`beck2009fast`, in combination with adaptive restart as described
    in O'Donoghue & Candès (2012) :cite:p:`o2015adaptive`.

    Parameters
    ----------
    lhs : numpy.ndarray
        The matrix A^T A.
    rhs : numpy.ndarray
        The vector A^T b.
    init : numpy.ndarray
        Initialization of `x`.
    l1_reg : float
        L1 regularization coefficient `reg`.
    prox : function
        Proximal operator with signature ``prox(x, reg)``.
    n_iter : int, default is 10
        Maximal number of iterations.
    tol : float, default is 1e-6
        Convergence tolerance. An iteration counts as static when
        ``||x_k - x_{k-1}|| * trace(lhs) < tol``. The solver stops after more
        than five consecutive static iterations.
    return_err : bool, default is False
        Return iteration errors if true.
    line_search : bool, default is True
        Perform backtracking line search if True.

    Returns
    -------
    x : numpy.ndarray
        The solution `x` that minimizes the given optimization problem.
    losses : list
        Only returned if `return_err` is True. The smooth loss (without the
        regularization term and without the constant ||b||^2) after each
        completed iteration. Empty if no iterations ran.
    """
    # If the provided data is all zeros, don't run FISTA; return a zero matrix.
    if np.linalg.norm(lhs) == 0 or np.linalg.norm(rhs) == 0:
        x = np.zeros_like(init)
        if return_err:
            return x, []
        return x

    A_norm = np.trace(lhs)
    if line_search:
        # trace / n is the mean eigenvalue, a lower bound on the largest one.
        # Halving it gives a deliberately optimistic start that the line
        # search raises as needed.
        lipschitz = np.trace(lhs) / (2 * lhs.shape[0])
    else:
        # For a PSD matrix the trace bounds the largest eigenvalue from above.
        lipschitz = np.trace(lhs)

    AtA = lhs
    At_b = rhs
    x = init
    y = init
    t = 1

    compute_smooth_loss = create_loss(AtA, At_b)
    compute_smooth_grad = create_gradient(AtA, At_b)

    loss_x = compute_smooth_loss(x)
    loss_y = loss_x
    smooth_grad_y = compute_smooth_grad(y)

    losses = []
    n_static = 0
    for _ in range(n_iter):
        # Simple FISTA update step
        new_x, new_y, new_t = fista_step(
            x,
            y,
            t,
            lipschitz=lipschitz,
            smooth_grad_y=smooth_grad_y,
            l1_reg=l1_reg,
            prox=prox,
        )
        loss_new_x = compute_smooth_loss(new_x)

        # Adaptive restart criterion from Equation 12 in O'Donoghue & Candès
        # (2012). If the loss is not decreasing monotonically, the momentum is
        # likely pushing the estimate in the wrong direction, so restart it.
        if loss_new_x > loss_x:
            y = x
            # y moved to x, so its loss is loss_x. The backtracking test below
            # needs the loss at the *current* y.
            loss_y = loss_x
            smooth_grad_y = compute_smooth_grad(y)
            t = 1
            new_x, new_y, new_t = fista_step(
                x,
                y,
                t,
                lipschitz=lipschitz,
                smooth_grad_y=smooth_grad_y,
                l1_reg=l1_reg,
                prox=prox,
            )
            loss_new_x = compute_smooth_loss(new_x)

        # Backtracking line search.
        # We backtrack at most five times, since the backtracking line search
        # may diverge for infeasible initial positions. The Lipschitz estimate
        # grows geometrically (x1.5 per backtrack), so after a few FISTA
        # iterations it becomes an actual upper bound.
        if line_search:
            for _line_search_it in range(5):
                if not _should_continue_backtracking(
                    new_x, y, loss_new_x, loss_y, smooth_grad_y, lipschitz
                ):
                    break
                lipschitz *= 1.5
                new_x, new_y, new_t = fista_step(
                    x,
                    y,
                    t,
                    lipschitz=lipschitz,
                    smooth_grad_y=smooth_grad_y,
                    l1_reg=l1_reg,
                    prox=prox,
                )
                loss_new_x = compute_smooth_loss(new_x)

        # Update loop variables
        prev_x = x
        x, y, t = new_x, new_y, new_t
        loss_x = loss_new_x
        loss_y = compute_smooth_loss(y)
        smooth_grad_y = compute_smooth_grad(y)
        losses.append(loss_x)

        if np.linalg.norm(prev_x - x) * A_norm < tol:
            n_static += 1
        else:
            n_static = 0

        # Stop after more than five consecutive static iterations.
        if n_static > 5:
            break

    if return_err:
        return x, losses
    return x
def fista_solve(
    lhs, rhs, l1_reg, nonnegative, normalize, init, n_iter_max=100, return_err=False
):
    """Choose a proximal operator from the requested constraints and run FISTA.

    Solves

        min_x ||Ax - b||^2 + reg * g(x)

    where `reg` is an l1 regularisation coefficient and `g` is the penalty
    whose proximal operator is chosen from the flags. The proximal operator
    can combine:

    - l1 regularization
    - non-negativity constraint
    - l2 norm constraint: ||x|| <= 1

    L1 regularization and the l2 norm constraint cannot be combined.

    Optimization is achieved using the Fast Iterative Shrinkage Thresholding
    Algorithm (FISTA) with backtracking, as described in Beck & Teboulle
    (2009) :cite:p:`beck2009fast`, in combination with adaptive restart as
    described in O'Donoghue & Candès (2012) :cite:p:`o2015adaptive`.

    Parameters
    ----------
    lhs : numpy.ndarray
        The matrix A^T A.
    rhs : numpy.ndarray
        The vector A^T b.
    l1_reg : float
        L1 regularization coefficient `reg`.
    nonnegative : bool
        If True, applies a non-negativity constraint to the solution `x`.
    normalize : bool
        If True, applies an l2 norm constraint to `x` such that ||x|| <= 1.
    init : numpy.ndarray
        Initialization of `x`.
    n_iter_max : int, default is 100
        Maximal number of iterations.
    return_err : bool, default is False
        Return iteration errors if true.

    Returns
    -------
    x : numpy.ndarray
        The solution `x` that minimizes the given optimization problem.
    losses : list
        If `return_err` = True, a list of iteration errors.

    Raises
    ------
    ValueError
        If both `normalize` and a nonzero `l1_reg` are requested.
    """
    if normalize and l1_reg:
        raise ValueError('Cannot normalize and apply l1 regularization on same mode.')

    if nonnegative:
        if l1_reg:
            prox = nn_l1_prox
        elif normalize:
            prox = nn_l2ball_prox
        else:
            prox = nn_prox
    elif normalize:
        prox = l2ball_prox
    else:
        # With l1_reg == 0 this reduces to the identity map.
        prox = l1_prox

    return minimise_fista(
        lhs,
        rhs,
        init,
        l1_reg,
        prox,
        n_iter=n_iter_max,
        tol=1e-6,
        return_err=return_err,
    )