jaxfit.common_jax module

These functions were originally in the common.py file, but because they perform large array operations they are better suited to compilation with JAX. They are compiled with JAX and then added to the CommonJIT class.

class jaxfit.common_jax.CommonJIT

Bases: object

build_quadratic_1d(J, g, s, diag=None, s0=None)

Parameterize a multivariate quadratic function along a line.

The resulting univariate quadratic function is given as follows:

f(t) = 0.5 * (s0 + s*t).T * (J.T*J + diag) * (s0 + s*t) + g.T * (s0 + s*t)
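Expanding f(t) in powers of t shows where the returned coefficients come from. Write A = J.T*J + D, where D is the diagonal matrix built from diag (D = 0 when diag is None), and take s0 = 0 when it is None:

```latex
\begin{aligned}
f(t) &= \tfrac{1}{2}(s_0 + t s)^\top A (s_0 + t s) + g^\top (s_0 + t s) = a t^2 + b t + c,\\
a &= \tfrac{1}{2} s^\top A s = \tfrac{1}{2}\left(\lVert J s\rVert^2 + s^\top D s\right),\\
b &= s_0^\top A s + g^\top s = (J s_0)^\top (J s) + s_0^\top D s + g^\top s,\\
c &= \tfrac{1}{2} s_0^\top A s_0 + g^\top s_0 = \tfrac{1}{2}\left(\lVert J s_0\rVert^2 + s_0^\top D s_0\right) + g^\top s_0 .
\end{aligned}
```

With s0 = 0 the constant term c vanishes and b reduces to g.T * s, which is consistent with c being returned only when s0 is provided.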

Parameters:
  • J (ndarray, sparse matrix or LinearOperator, shape (m, n)) – Jacobian matrix, affects the quadratic term.

  • g (ndarray, shape (n,)) – Gradient, defines the linear term.

  • s (ndarray, shape (n,)) – Direction vector of a line.

  • diag (None or ndarray with shape (n,), optional) – Additional diagonal part, affects the quadratic term. If None, assumed to be 0.

  • s0 (None or ndarray with shape (n,), optional) – Initial point. If None, assumed to be 0.

Returns:

  • a (float) – Coefficient for t**2.

  • b (float) – Coefficient for t.

  • c (float) – Free term. Returned only if s0 is provided.

Return type:

Tuple[ndarray, ndarray, ndarray] | Tuple[ndarray, ndarray]
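A minimal NumPy sketch of the same computation is shown here. It is an illustration of the formula, not the compiled jaxfit code: it assumes a dense J (no sparse matrix or LinearOperator support) and treats diag as the diagonal entries of the extra quadratic term, applied elementwise.

```python
import numpy as np


def build_quadratic_1d(J, g, s, diag=None, s0=None):
    """Sketch: coefficients of f(t) = a*t**2 + b*t + c along the line s0 + s*t.

    Assumes a dense J; diag (if given) holds the diagonal of the extra term.
    """
    v = J @ s                          # J s
    a = v @ v                          # ||J s||^2
    if diag is not None:
        a += (s * diag) @ s            # s^T D s
    a *= 0.5

    b = g @ s                          # linear term g^T s
    if s0 is None:
        return a, b                    # c would be 0, so it is not returned

    u = J @ s0                         # J s0
    b += u @ v                         # cross term (J s0)^T (J s)
    c = 0.5 * (u @ u) + g @ s0         # value of f at t = 0
    if diag is not None:
        b += (s0 * diag) @ s
        c += 0.5 * ((s0 * diag) @ s0)
    return a, b, c
```

The resulting polynomial a*t**2 + b*t + c can then be minimized over an interval of t in closed form, which is why only the three scalars are needed.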

compute_jac_scale(J, scale_inv_old=None)

Compute the scale of the variables based on the Jacobian matrix.

Parameters:
  • J (jnp.ndarray) – Jacobian matrix.

  • scale_inv_old (Optional[np.ndarray], optional) – Previous inverse scale, by default None.

Returns:

  • scale (np.ndarray) – Scale for the variables.

  • scale_inv (np.ndarray) – Inverse of the scale for the variables.

Return type:

Tuple[ndarray, ndarray]
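The sketch below shows one plausible way to compute these values in NumPy. It is modeled on the SciPy least_squares helper of the same name and on the description of create_jac_sum (square the Jacobian, sum, take the square root); whether the compiled jaxfit version handles zero columns and scale_inv_old in exactly this way is an assumption.

```python
import numpy as np


def compute_jac_scale(J, scale_inv_old=None):
    """Sketch of Jacobian-based variable scaling for a dense J of shape (m, n)."""
    # Inverse scale of each variable: Euclidean norm of the matching column of J.
    scale_inv = np.sqrt(np.sum(J**2, axis=0))
    if scale_inv_old is None:
        # Avoid division by zero for variables the residuals do not depend on.
        scale_inv[scale_inv == 0] = 1.0
    else:
        # Assumed update rule: the inverse scale never shrinks between calls.
        scale_inv = np.maximum(scale_inv, scale_inv_old)
    return 1.0 / scale_inv, scale_inv
```

For example, J = [[3, 0], [4, 0]] gives column norms [5, 0]; the zero norm is replaced by 1, so scale_inv is [5, 1] and scale is [0.2, 1].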

create_jac_sum()

Create the function that sums the squared Jacobian entries and then takes the square root. It is used to compute the scale of the variables. It may be possible to remove this function.

create_js_dot()

Create the functions for the dot product of the Jacobian and the search direction. Two functions are needed because s and s0 have different shapes, so using a single compiled function for both would cause it to be retraced.

create_quadratic_funcs()
create_scale_for_robust_loss_function()

Create the scaling function for the robust loss functions.
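The docstring does not state what the scaling does. The sketch below follows the approach SciPy's least_squares uses for robust losses, which is assumed (not confirmed by this page) to be what the compiled function computes. Here rho is a hypothetical (3, m) array holding rho(z), rho'(z) and rho''(z) evaluated at z = f**2, and J is assumed dense.

```python
import numpy as np

EPS = np.finfo(float).eps


def scale_for_robust_loss_function(J, f, rho):
    """Sketch: rescale the Jacobian J (m, n) and residuals f (m,) for a robust loss.

    rho has shape (3, m): rho[0] = rho(z), rho[1] = rho'(z), rho[2] = rho''(z), z = f**2.
    """
    J_scale = rho[1] + 2 * rho[2] * f**2
    # Clip from below so the square root and the division below stay well defined.
    J_scale = np.sqrt(np.maximum(J_scale, EPS))
    f_scaled = f * rho[1] / J_scale
    J_scaled = J_scale[:, None] * J    # scale each row (residual) of J
    return J_scaled, f_scaled
```

For the plain linear loss rho(z) = z, rho' = 1 and rho'' = 0, so J and f come back unchanged; a robust loss such as soft L1 shrinks the rows belonging to large residuals.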

evaluate_quadratic(J, g, s_np, diag=None)

Compute values of a quadratic function arising in least squares. The function is 0.5 * s.T * (J.T * J + diag) * s + g.T * s.

Parameters:
  • J (ndarray, sparse matrix or LinearOperator, shape (m, n)) – Jacobian matrix, affects the quadratic term.

  • g (ndarray, shape (n,)) – Gradient, defines the linear term.

  • s_np (ndarray, shape (k, n) or (n,)) – Array containing steps as rows.

  • diag (ndarray, shape (n,), optional) – Additional diagonal part, affects the quadratic term. If None, assumed to be 0.

Returns:

values – Values of the function. If s_np is 2-D, an ndarray is returned; otherwise, a float is returned.

Return type:

ndarray with shape (k,) or float
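A NumPy sketch of the same evaluation, handling both a single step and a batch of steps stored as rows, is given below. It illustrates the formula rather than the compiled jaxfit code and assumes a dense J; the step argument is named s here, corresponding to s_np in the signature above.

```python
import numpy as np


def evaluate_quadratic(J, g, s, diag=None):
    """Sketch: 0.5 * s^T (J^T J + diag) s + g^T s for one step (n,) or k steps (k, n)."""
    if s.ndim == 1:
        Js = J @ s
        q = Js @ Js                            # ||J s||^2, a scalar
        if diag is not None:
            q += (s * diag) @ s
    else:
        Js = J @ s.T                           # shape (m, k): one column per step
        q = np.sum(Js**2, axis=0)              # ||J s_i||^2 for each row s_i
        if diag is not None:
            q += np.sum(diag * s**2, axis=1)
    return 0.5 * q + s @ g                     # scalar for 1-D s, shape (k,) for 2-D s
```

Batching the steps as rows lets a trust-region method compare several candidate steps with a single matrix product instead of a Python loop.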