Source code for jaxfit.common_jax

"""
These are functions that were initially in the common.py file, but they
involve large data operations and are therefore better suited to being
compiled with JAX.  They are compiled with JAX and then added to the
CommonJIT class.
"""
import numpy as np
from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jit
import time
from typing import Tuple, List, Dict, Union, Optional, Callable

# Machine epsilon of Python's float type, as reported by np.finfo. It is used
# as the lower clamp for the Jacobian scale in
# CommonJIT.scale_for_robust_loss_function.
EPS = np.finfo(float).eps


class CommonJIT:
    """Collection of JAX/JIT-compiled helpers for the least-squares routines.

    The jitted callables are created once, when the object is constructed,
    and stored as instance attributes. The regular methods wrap those
    callables and take care of converting between NumPy and JAX arrays.
    """

    def __init__(self):
        """Create every jitted helper and attach it to the instance."""
        self.create_quadratic_funcs()
        self.create_js_dot()
        self.create_jac_sum()
        self.create_scale_for_robust_loss_function()

    def create_scale_for_robust_loss_function(self):
        """Create the jitted scaling function used with robust losses.

        The function is stored as ``self.scale_for_robust_loss_function``.
        """

        @jit
        def scale_for_robust_loss_function(J: jnp.ndarray,
                                           f: jnp.ndarray,
                                           rho: jnp.ndarray
                                           ) -> Tuple[jnp.ndarray,
                                                      jnp.ndarray]:
            """Scale Jacobian and residuals for a robust loss function.

            The scaled quantities are returned as new arrays.

            Parameters
            ----------
            J : jnp.ndarray
                Jacobian matrix.
            f : jnp.ndarray
                Residuals.
            rho : jnp.ndarray
                Cost function evaluation; ``rho[1]`` and ``rho[2]`` are
                read.

            Returns
            -------
            J : jnp.ndarray
                Jacobian with every row multiplied by its scale factor.
            f : jnp.ndarray
                Residuals multiplied by ``rho[1] / scale``.
            """
            rho1 = rho[1]
            weight = rho1 + 2 * rho[2] * f**2
            # Clamp at EPS so the square root and the division below stay
            # well defined for zero or negative weights.
            weight = jnp.where(weight < EPS, EPS, weight)
            row_scale = weight**0.5
            scaled_f = f * (rho1 / row_scale)
            scaled_J = J * row_scale[:, jnp.newaxis]
            return scaled_J, scaled_f

        self.scale_for_robust_loss_function = scale_for_robust_loss_function

    def build_quadratic_1d(self,
                           J: jnp.ndarray,
                           g: jnp.ndarray,
                           s: jnp.ndarray,
                           diag: Optional[jnp.ndarray] = None,
                           s0: Optional[jnp.ndarray] = None
                           ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray],
                                      Tuple[np.ndarray, np.ndarray]]:
        """Restrict a multivariate quadratic to a line.

        The univariate quadratic that results is::

            f(t) = 0.5 * (s0 + s*t).T * (J.T*J + diag) * (s0 + s*t) +
                   g.T * (s0 + s*t)

        Parameters
        ----------
        J : ndarray, shape (m, n)
            Jacobian matrix, affects the quadratic term.
        g : ndarray, shape (n,)
            Gradient, defines the linear term.
        s : ndarray, shape (n,)
            Direction vector of the line.
        diag : None or ndarray with shape (n,), optional
            Additional diagonal part of the quadratic term. Treated as 0
            when None.
        s0 : None or ndarray with shape (n,), optional
            Initial point. Treated as 0 when None.

        Returns
        -------
        a : float
            Coefficient for t**2.
        b : float
            Coefficient for t.
        c : float
            Free term. Returned only if `s0` is provided.
        """
        # The J-products run through the jitted helpers; the remaining
        # scalar arithmetic is done with NumPy on the host.
        v = np.asarray(self.js_dot(J, jnp.array(s)))
        has_diag = diag is not None

        a = np.dot(v, v)
        if has_diag:
            a += np.dot(s * diag, s)
        a *= 0.5

        b = np.dot(g, s)
        if s0 is None:
            return a, b

        u = np.asarray(self.js0_dot(J, jnp.array(s0)))
        b += np.dot(u, v)
        c = 0.5 * np.dot(u, u) + np.dot(g, s0)
        if has_diag:
            b += np.dot(s0 * diag, s)
            c += 0.5 * np.dot(s0 * diag, s0)
        return a, b, c

    def compute_jac_scale(self,
                          J: jnp.ndarray,
                          scale_inv_old: Optional[np.ndarray] = None
                          ) -> Tuple[np.ndarray, np.ndarray]:
        """Derive the variable scaling from the Jacobian column norms.

        Parameters
        ----------
        J : jnp.ndarray
            Jacobian matrix.
        scale_inv_old : Optional[np.ndarray], optional
            Previous inverse scale, by default None. When given, the new
            inverse scale is the elementwise maximum of the column norms
            and this array.

        Returns
        -------
        scale : np.ndarray
            Scale for the variables.
        scale_inv : np.ndarray
            Inverse of the scale for the variables.
        """
        # np.array makes a writable host copy of the JAX result.
        scale_inv = np.array(self.jac_sum_func(J))
        if scale_inv_old is not None:
            scale_inv = np.maximum(scale_inv, scale_inv_old)
        else:
            # No history yet: replace zero norms by 1 before inverting.
            scale_inv[scale_inv == 0] = 1
        return 1 / scale_inv, scale_inv

    def create_js_dot(self):
        """Create the jitted Jacobian-times-vector products.

        Two separately jitted copies are stored: ``self.js_dot`` for the
        search direction ``s`` and ``self.js0_dot`` for the initial point
        ``s0``. The stated rationale is that ``s`` and ``s0`` have different
        shapes, so sharing one jitted function would cause retracing.
        """

        @jit
        def js_dot(J: jnp.ndarray, s: jnp.ndarray) -> jnp.ndarray:
            return jnp.dot(J, s)

        @jit
        def js0_dot(J: jnp.ndarray, s0: jnp.ndarray) -> jnp.ndarray:
            return jnp.dot(J, s0)

        self.js_dot = js_dot
        self.js0_dot = js0_dot

    def evaluate_quadratic(self,
                           J: jnp.ndarray,
                           g: jnp.ndarray,
                           s_np: np.ndarray,
                           diag: Optional[np.ndarray] = None
                           ) -> jnp.ndarray:
        """Evaluate the quadratic model that arises in least squares.

        The function is 0.5 * s.T * (J.T * J + diag) * s + g.T * s.

        Parameters
        ----------
        J : ndarray, shape (m, n)
            Jacobian matrix, affects the quadratic term.
        g : ndarray, shape (n,)
            Gradient, defines the linear term.
        s_np : ndarray, shape (k, n) or (n,)
            Array containing steps as rows.
        diag : ndarray, shape (n,), optional
            Additional diagonal part of the quadratic term. Treated as 0
            when None.

        Returns
        -------
        values : ndarray with shape (k,) or scalar
            Values of the function: one value per row when `s_np` is 2-D,
            a single value otherwise.
        """
        s = jnp.array(s_np)  # arrives as a NumPy array
        single_step = s.ndim == 1

        # Pick the jitted kernel that matches the step rank and whether a
        # diagonal term is present.
        if diag is None:
            kernel = (self.evaluate_quadratic1 if single_step
                      else self.evaluate_quadratic2)
            return kernel(J, g, s)

        kernel = (self.evaluate_quadratic_diagonal1 if single_step
                  else self.evaluate_quadratic_diagonal2)
        return kernel(J, g, s, diag)

    def create_quadratic_funcs(self):
        """Create the four jitted kernels behind ``evaluate_quadratic``.

        The suffix ``1`` marks the kernels for a single 1-D step and ``2``
        the kernels for a 2-D array with one step per row; the
        ``_diagonal`` variants add the diagonal contribution to the
        quadratic term.
        """

        @jit
        def evaluate_quadratic1(J, g, s):
            Js = J.dot(s)
            quad = jnp.dot(Js, Js)
            linear = jnp.dot(s, g)
            return 0.5 * quad + linear

        @jit
        def evaluate_quadratic_diagonal1(J, g, s, diag):
            Js = J.dot(s)
            quad = jnp.dot(Js, Js) + jnp.dot(s * diag, s)
            linear = jnp.dot(s, g)
            return 0.5 * quad + linear

        @jit
        def evaluate_quadratic2(J, g, s):
            # Columns of Js correspond to the individual steps.
            Js = J.dot(s.T)
            quad = jnp.sum(Js**2, axis=0)
            linear = jnp.dot(s, g)
            return 0.5 * quad + linear

        @jit
        def evaluate_quadratic_diagonal2(J, g, s, diag):
            Js = J.dot(s.T)
            quad = jnp.sum(Js**2, axis=0) + jnp.sum(diag * s**2, axis=1)
            linear = jnp.dot(s, g)
            return 0.5 * quad + linear

        self.evaluate_quadratic1 = evaluate_quadratic1
        self.evaluate_quadratic_diagonal1 = evaluate_quadratic_diagonal1
        self.evaluate_quadratic2 = evaluate_quadratic2
        self.evaluate_quadratic_diagonal2 = evaluate_quadratic_diagonal2

    def create_jac_sum(self):
        """Create the jitted column-norm function of the Jacobian.

        It sums the squared entries of ``J`` along axis 0 and takes the
        square root; ``compute_jac_scale`` uses the result to scale the
        variables. Stored as ``self.jac_sum_func``. Can potentially be
        removed.
        """

        @jit
        def jac_sum_func(J):
            column_sq_sum = jnp.sum(J**2, axis=0)
            return column_sq_sum**0.5

        self.jac_sum_func = jac_sum_func