API Documentation

Module contents

Created on Mon Mar 14 19:26:59 2022

@author: hofer

class jaxfit.AutoDiffJacobian[source]

Bases: object

Wraps the residual fit function such that a masked jacfwd is performed on it, thereby giving the autodiff Jacobian. This needs to be a class since we need to maintain three different versions of the Jacobian in memory.

create_ad_jacobian(func, num_args, masked=True)[source]

Creates a function that returns the autodiff jacobian of the residual fit function. The Jacobian of the residual fit function is equivalent to the Jacobian of the fit function.

Parameters:
  • func (Callable) – The function to take the jacobian of.

  • num_args (int) – The number of arguments the function takes.

  • masked (bool, optional) – Whether to use a masked jacobian, by default True

Returns:

The function that returns the autodiff jacobian of the given function.

Return type:

Callable

jaxfit.CL_scaling_vector(x, g, lb, ub)[source]

Compute Coleman-Li scaling vector and its derivatives. Components of a vector v are defined as follows:

       | ub[i] - x[i], if g[i] < 0 and ub[i] < np.inf
v[i] = | x[i] - lb[i], if g[i] > 0 and lb[i] > -np.inf
       | 1,           otherwise

According to this definition v[i] >= 0 for all i. It differs from the definition in paper [5] (eq. (2.2)), where the absolute value of v is used. Both definitions are equivalent down the line. Derivatives of v with respect to x take the value 1, -1 or 0 depending on the case.

Returns:

  • v (ndarray with shape of x) – Scaling vector.

  • dv (ndarray with shape of x) – Derivatives of v[i] with respect to x[i], diagonal elements of v’s Jacobian.
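The piecewise definition above can be sketched in plain NumPy. This is only an illustration of the stated definition, not necessarily the code jaxfit runs; the helper name cl_scaling_vector_sketch is hypothetical.

```python
import numpy as np

def cl_scaling_vector_sketch(x, g, lb, ub):
    """Illustrative Coleman-Li scaling vector (hypothetical helper name)."""
    x, g, lb, ub = (np.asarray(a, dtype=float) for a in (x, g, lb, ub))
    v = np.ones_like(x)    # default case: v[i] = 1
    dv = np.zeros_like(x)  # default case: dv[i] = 0

    # g[i] < 0 with a finite upper bound: v[i] = ub[i] - x[i], dv[i] = -1
    upper = (g < 0) & np.isfinite(ub)
    v[upper] = ub[upper] - x[upper]
    dv[upper] = -1.0

    # g[i] > 0 with a finite lower bound: v[i] = x[i] - lb[i], dv[i] = 1
    lower = (g > 0) & np.isfinite(lb)
    v[lower] = x[lower] - lb[lower]
    dv[lower] = 1.0
    return v, dv

v, dv = cl_scaling_vector_sketch(
    x=[0.5, 2.0, 1.0],
    g=[-1.0, 1.0, 0.0],
    lb=[0.0, 0.0, -np.inf],
    ub=[1.0, np.inf, np.inf],
)
```

Each of the three components exercises one branch of the definition: the first uses the upper bound, the second the lower bound, and the third falls through to the default.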

References

class jaxfit.CommonJIT[source]

Bases: object

build_quadratic_1d(J, g, s, diag=None, s0=None)[source]

Parameterize a multivariate quadratic function along a line.

The resulting univariate quadratic function is given as follows:

f(t) = 0.5 * (s0 + s*t).T * (J.T*J + diag) * (s0 + s*t) + g.T * (s0 + s*t)

Parameters:
  • J (ndarray, sparse matrix or LinearOperator, shape (m, n)) – Jacobian matrix, affects the quadratic term.

  • g (ndarray, shape (n,)) – Gradient, defines the linear term.

  • s (ndarray, shape (n,)) – Direction vector of a line.

  • diag (None or ndarray with shape (n,), optional) – Additional diagonal part, affects the quadratic term. If None, assumed to be 0.

  • s0 (None or ndarray with shape (n,), optional) – Initial point. If None, assumed to be 0.

Returns:

  • a (float) – Coefficient for t**2.

  • b (float) – Coefficient for t.

  • c (float) – Free term. Returned only if s0 is provided.

Return type:

Tuple[ndarray, ndarray, ndarray] | Tuple[ndarray, ndarray]
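Expanding f(t) in powers of t gives the three coefficients directly. The NumPy sketch below works through that expansion for a dense J and checks it against a direct evaluation; the helper name build_quadratic_1d_sketch is hypothetical and the code is not taken from jaxfit.

```python
import numpy as np

def build_quadratic_1d_sketch(J, g, s, diag=None, s0=None):
    """Coefficients of f(t) along the line s0 + s*t (hypothetical helper)."""
    v = J @ s
    a = v @ v                      # s.T J.T J s
    if diag is not None:
        a += (s * diag) @ s        # s.T diag s
    a *= 0.5
    b = g @ s
    if s0 is None:
        return a, b
    u = J @ s0
    b += u @ v                     # cross term s0.T J.T J s
    c = 0.5 * (u @ u) + g @ s0
    if diag is not None:
        b += (s0 * diag) @ s
        c += 0.5 * (s0 * diag) @ s0
    return a, b, c

rng = np.random.default_rng(0)
J = rng.normal(size=(5, 3))
g = rng.normal(size=3)
s = rng.normal(size=3)
s0 = rng.normal(size=3)
diag = np.array([0.1, 0.2, 0.3])

a, b, c = build_quadratic_1d_sketch(J, g, s, diag=diag, s0=s0)

# Compare a*t**2 + b*t + c with the multivariate quadratic at one point.
t = 0.7
p = s0 + s * t
direct = 0.5 * p @ (J.T @ J + np.diag(diag)) @ p + g @ p
from_coeffs = a * t**2 + b * t + c
```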

compute_jac_scale(J, scale_inv_old=None)[source]

Compute variables scale based on the Jacobian matrix.

Parameters:
  • J (jnp.ndarray) – Jacobian matrix.

  • scale_inv_old (Optional[np.ndarray], optional) – Previous scale, by default None

Returns:

  • scale (np.ndarray) – Scale for the variables.

  • scale_inv (np.ndarray) – Inverse of the scale for the variables.

Return type:

Tuple[ndarray, ndarray]

create_jac_sum()[source]

Create the function that sums the squared entries of the Jacobian and then takes the square root. This is used to compute the scale for the variables. This function could potentially be removed.
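The scale computation described above (square root of the summed squared Jacobian entries) can be sketched in NumPy. Summing per column, replacing zero norms by 1, and taking the elementwise maximum with the previous scale all follow SciPy's convention and are assumptions here; the helper name is hypothetical.

```python
import numpy as np

def compute_jac_scale_sketch(J, scale_inv_old=None):
    """Column-norm based variable scale (illustrative, SciPy-style)."""
    # Euclidean norm of each Jacobian column.
    scale_inv = np.sqrt(np.sum(J**2, axis=0))
    if scale_inv_old is None:
        # Assumed: avoid division by zero for all-zero columns.
        scale_inv[scale_inv == 0] = 1.0
    else:
        # Assumed: the inverse scale is never allowed to shrink.
        scale_inv = np.maximum(scale_inv, scale_inv_old)
    return 1.0 / scale_inv, scale_inv

J = np.array([[3.0, 0.0],
              [4.0, 0.0]])
scale, scale_inv = compute_jac_scale_sketch(J)
scale2, scale_inv2 = compute_jac_scale_sketch(J, scale_inv_old=np.array([6.0, 2.0]))
```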

create_js_dot()[source]

Create the functions for the dot product of the Jacobian and the search direction. We need two functions because s and s0 have different shapes, which would cause the function to be retraced if the same function were used for both.

create_quadratic_funcs()[source]

create_scale_for_robust_loss_function()[source]

Create the scaling function for the loss functions.

evaluate_quadratic(J, g, s_np, diag=None)[source]

Compute values of a quadratic function arising in least squares. The function is 0.5 * s.T * (J.T * J + diag) * s + g.T * s.

Parameters:
  • J (ndarray, sparse matrix or LinearOperator, shape (m, n)) – Jacobian matrix, affects the quadratic term.

  • g (ndarray, shape (n,)) – Gradient, defines the linear term.

  • s (ndarray, shape (k, n) or (n,)) – Array containing steps as rows.

  • diag (ndarray, shape (n,), optional) – Additional diagonal part, affects the quadratic term. If None, assumed to be 0.

  • s_np (ndarray) –

Returns:

values – Values of the function. If s is 2-D, an ndarray is returned; otherwise, a float is returned.

Return type:

ndarray with shape (k,) or float
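The quadratic 0.5 * s.T * (J.T * J + diag) * s + g.T * s can be evaluated without forming J.T * J. The following NumPy sketch handles both a single step and a stack of steps given as rows; it assumes a dense J, and the helper name is hypothetical.

```python
import numpy as np

def evaluate_quadratic_sketch(J, g, s, diag=None):
    """Evaluate the least-squares quadratic for one step or k steps."""
    s = np.asarray(s, dtype=float)
    if s.ndim == 1:
        Js = J @ s
        q = Js @ Js                       # s.T J.T J s
        if diag is not None:
            q += (s * diag) @ s
    else:
        Js = J @ s.T                      # shape (m, k), one column per step
        q = np.sum(Js**2, axis=0)         # shape (k,)
        if diag is not None:
            q += np.sum(diag * s**2, axis=1)
    return 0.5 * q + s @ g                # add the linear term g.T s

J = np.array([[1.0, 2.0],
              [0.0, 1.0],
              [1.0, 0.0]])
g = np.array([1.0, -1.0])
diag = np.array([0.5, 0.5])
steps = np.array([[1.0, 0.0],
                  [0.0, 1.0]])

values = evaluate_quadratic_sketch(J, g, steps, diag=diag)    # shape (2,)
single = evaluate_quadratic_sketch(J, g, steps[0], diag=diag)  # scalar
```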

class jaxfit.CurveFit(flength=None)[source]

Bases: object

Parameters:

flength (float | None) –

create_covariance_svd()[source]

Create JIT-compiled SVD function for covariance computation.

create_sigma_transform_funcs()[source]

Create JIT-compiled sigma transform functions.

This function creates two JIT-compiled functions: sigma_transform1d and sigma_transform2d, which are used to compute the sigma transform for 1D and 2D data, respectively. The functions are stored as attributes of the object on which the method is called.

curve_fit(f, xdata, ydata, p0=None, sigma=None, absolute_sigma=False, check_finite=True, bounds=(-inf, inf), method=None, jac=None, data_mask=None, timeit=False, return_eval=False, **kwargs)[source]

Use non-linear least squares to fit a function, f, to data. Assumes ydata = f(xdata, *params) + eps.

Parameters:
  • f (callable) – The model function, f(x, …). It must take the independent variable as the first argument and the parameters to fit as separate remaining arguments.

  • xdata (array_like or object) – The independent variable where the data is measured. Should usually be an M-length sequence or a (k, M)-shaped array for functions with k predictors, but can actually be any object.

  • ydata (array_like) – The dependent data, a length M array - nominally f(xdata, ...).

  • p0 (array_like, optional) – Initial guess for the parameters (length N). If None, then the initial values will all be 1 (if the number of parameters for the function can be determined using introspection, otherwise a ValueError is raised).

  • sigma (None or M-length sequence or MxM array, optional) –

    Determines the uncertainty in ydata. If we define residuals as r = ydata - f(xdata, *popt), then the interpretation of sigma depends on its number of dimensions:

    • A 1-D sigma should contain values of standard deviations of errors in ydata. In this case, the optimized function is chisq = sum((r / sigma) ** 2).

    • A 2-D sigma should contain the covariance matrix of errors in ydata. In this case, the optimized function is chisq = r.T @ inv(sigma) @ r.

    New in version 0.19.

    None (default) is equivalent of 1-D sigma filled with ones.

  • absolute_sigma (bool, optional) – If True, sigma is used in an absolute sense and the estimated parameter covariance pcov reflects these absolute values. If False (default), only the relative magnitudes of the sigma values matter. The returned parameter covariance matrix pcov is based on scaling sigma by a constant factor. This constant is set by demanding that the reduced chisq for the optimal parameters popt when using the scaled sigma equals unity. In other words, sigma is scaled to match the sample variance of the residuals after the fit. Default is False. Mathematically, pcov(absolute_sigma=False) = pcov(absolute_sigma=True) * chisq(popt)/(M-N)

  • check_finite (bool, optional) – If True, check that the input arrays do not contain NaNs or infs, and raise a ValueError if they do. Setting this parameter to False may silently produce nonsensical results if the input arrays do contain NaNs. Default is True.

  • bounds (2-tuple of array_like, optional) – Lower and upper bounds on parameters. Defaults to no bounds. Each element of the tuple must be either an array with the length equal to the number of parameters, or a scalar (in which case the bound is taken to be the same for all parameters). Use np.inf with an appropriate sign to disable bounds on all or some parameters. New in version 0.17.

  • method ({'trf'}, optional) – Method to use for optimization. See least_squares for more details. Currently only ‘trf’ is implemented. New in version 0.17.

  • jac (callable, string or None, optional) – Function with signature jac(x, ...) which computes the Jacobian matrix of the model function with respect to parameters as a dense array_like structure. It will be scaled according to provided sigma. If None (default), the Jacobian will be determined using JAX’s automatic differentiation (AD) capabilities. We recommend not using an analytical Jacobian, as it is usually faster to use AD.

  • kwargs – Keyword arguments passed to leastsq for method='lm' or least_squares otherwise.

  • data_mask (ndarray | None) –

  • timeit (bool) –

  • return_eval (bool) –

Returns:

  • popt (array) – Optimal values for the parameters so that the sum of the squared residuals of f(xdata, *popt) - ydata is minimized.

  • pcov (2-D array) – The estimated covariance of popt. The diagonals provide the variance of the parameter estimate. To compute one-standard-deviation errors on the parameters, use perr = np.sqrt(np.diag(pcov)). How the sigma parameter affects the estimated covariance depends on the absolute_sigma argument, as described above. If the Jacobian matrix at the solution doesn’t have full rank, then the ‘lm’ method returns a matrix filled with np.inf; the ‘trf’ and ‘dogbox’ methods, on the other hand, use the Moore-Penrose pseudoinverse to compute the covariance matrix.

Raises:
  • ValueError – if either ydata or xdata contain NaNs, or if incompatible options are used.

  • RuntimeError – if the least-squares minimization fails.

  • OptimizeWarning – if covariance of the parameters cannot be estimated.

Return type:

Tuple[ndarray, ndarray]

See also

least_squares

Minimize the sum of squares of nonlinear functions.

Notes

Refer to the docstring of least_squares for more information.
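The two interpretations of sigma described under Parameters can be checked numerically. The sketch below uses plain NumPy only to illustrate the two chisq formulas; it is not jaxfit's internal code, and the helper names chisq_1d and chisq_2d are hypothetical.

```python
import numpy as np

def chisq_1d(r, sigma):
    """chisq = sum((r / sigma) ** 2) for a 1-D sigma of standard deviations."""
    return np.sum((r / sigma) ** 2)

def chisq_2d(r, cov):
    """chisq = r.T @ inv(cov) @ r for a 2-D covariance matrix."""
    # Solve cov @ z = r instead of forming the inverse explicitly.
    return r @ np.linalg.solve(cov, r)

r = np.array([0.1, -0.2, 0.3])      # example residuals
sigma = np.array([0.5, 1.0, 2.0])   # example standard deviations

value_1d = chisq_1d(r, sigma)
# A diagonal covariance with entries sigma**2 reproduces the 1-D case.
value_2d = chisq_2d(r, np.diag(sigma**2))
```

The agreement of the two values shows why a 1-D sigma can be seen as the special case of uncorrelated errors.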

Examples

>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from jaxfit import CurveFit
>>> def func(x, a, b, c):
...     return a * jnp.exp(-b * x) + c

Define the data to be fit with some noise:

>>> xdata = np.linspace(0, 4, 50)
>>> y = func(xdata, 2.5, 1.3, 0.5)
>>> rng = np.random.default_rng()
>>> y_noise = 0.2 * rng.normal(size=xdata.size)
>>> ydata = y + y_noise
>>> plt.plot(xdata, ydata, 'b-', label='data')

Fit for the parameters a, b, c of the function func:

>>> cf = CurveFit()
>>> popt, pcov = cf.curve_fit(func, xdata, ydata)
>>> popt
array([2.56274217, 1.37268521, 0.47427475])
>>> plt.plot(xdata, func(xdata, *popt), 'r-',
...          label='fit: a=%5.3f, b=%5.3f, c=%5.3f' % tuple(popt))

Constrain the optimization to the region of 0 <= a <= 3, 0 <= b <= 1 and 0 <= c <= 0.5:

>>> cf = CurveFit()
>>> popt, pcov = cf.curve_fit(func, xdata, ydata, bounds=(0, [3., 1., 0.5]))
>>> popt
array([2.43736712, 1.        , 0.34463856])
>>> plt.plot(xdata, func(xdata, *popt), 'g--',
...          label='fit: a=%5.3f, b=%5.3f, c=%5.3f' % tuple(popt))
>>> plt.xlabel('x')
>>> plt.ylabel('y')
>>> plt.legend()
>>> plt.show()

pad_fit_data(xdata, ydata, xdims, len_diff)[source]

Pad fit data to match the fixed input data length.

This function pads the input data arrays with small values to match the fixed input data length, which prevents JAX from retracing the JITted functions. The padding is added along the second dimension of the xdata array if the data is multidimensional, and along the first dimension otherwise. The padding value is EPS, a global constant defined as a very small positive value, which avoids numerical issues.

Parameters:
  • xdata (np.ndarray) – The independent variables of the data.

  • ydata (np.ndarray) – The dependent variables of the data.

  • xdims (int) – The number of dimensions in the xdata array.

  • len_diff (int) – The difference in length between the data arrays and the fixed input data length.

Returns:

The padded xdata and ydata arrays.

Return type:

Tuple[np.ndarray, np.ndarray]
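The padding rule described above can be sketched as follows. The value chosen for EPS and the helper name are assumptions made for illustration; the actual constant used by jaxfit is not stated here.

```python
import numpy as np

EPS = 1e-10  # assumed value; stands in for the library's small positive constant

def pad_fit_data_sketch(xdata, ydata, xdims, len_diff):
    """Pad xdata and ydata with EPS up to a fixed length (illustrative)."""
    if xdims > 1:
        # Multidimensional xdata has shape (xdims, M): pad the second axis.
        xpad = EPS * np.ones((xdims, len_diff))
        xdata = np.concatenate([xdata, xpad], axis=1)
    else:
        # 1-D xdata has shape (M,): pad the only axis.
        xdata = np.concatenate([xdata, EPS * np.ones(len_diff)])
    ydata = np.concatenate([ydata, EPS * np.ones(len_diff)])
    return xdata, ydata

# 1-D data padded from length 3 to the fixed length 5.
x1, y1 = pad_fit_data_sketch(np.arange(3.0), np.arange(3.0), xdims=1, len_diff=2)
# Two predictors, shape (2, 3), padded to shape (2, 5).
x2, y2 = pad_fit_data_sketch(np.ones((2, 3)), np.arange(3.0), xdims=2, len_diff=2)
```

Because every call then sees arrays of the same shape, a JIT-compiled function does not need to be traced again when the amount of real data changes. The padded points would be excluded from the fit with a data mask.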

update_flength(flength)[source]

Set the fixed input data length.

Parameters:

flength (float) – The fixed input data length.

class jaxfit.LeastSquares[source]

Bases: object

autdiff_jac(jac)[source]

Creates the autodiff Jacobian for all three sigma-transformed functions, so that no retracing is needed if sigma is changed between None, a 1-D sigma, and a covariance-matrix sigma.

Parameters:

jac (None) – Passed in to maintain compatibility with the user defined Jacobian function.

Return type:

None

least_squares(fun, x0, jac=None, bounds=(-inf, inf), method='trf', ftol=1e-08, xtol=1e-08, gtol=1e-08, x_scale=1.0, loss='linear', f_scale=1.0, diff_step=None, tr_solver=None, tr_options={}, jac_sparsity=None, max_nfev=None, verbose=0, xdata=None, ydata=None, data_mask=None, transform=None, timeit=False, args=(), kwargs={})[source]
Parameters:
  • fun (Callable) –

  • x0 (ndarray) –

  • jac (Callable | None) –

  • bounds (Tuple[ndarray, ndarray]) –

  • method (str) –

  • ftol (float) –

  • xtol (float) –

  • gtol (float) –

  • x_scale (str | ndarray | float) –

  • loss (str) –

  • f_scale (float) –

  • max_nfev (float | None) –

  • verbose (int) –

  • xdata (Array | None) –

  • ydata (Array | None) –

  • data_mask (Array | None) –

  • transform (Array | None) –

  • timeit (bool) –

update_function(func)[source]

Wraps the given fit function to be a residual function using the data. The wrapped function is in a JAX JIT compatible format which is purely functional. This requires that both the data mask and the uncertainty transform are passed to the function. Even for the case where the data mask is all True and the uncertainty transform is None, we still need to pass these arguments to the function due to JAX’s functional nature.

Parameters:

func (Callable) – The fit function to wrap.

Return type:

None
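The wrapping described above can be illustrated with a small NumPy sketch. It assumes the transform, when given, is a 1-D array of elementwise weights, and it uses np.where in place of the JAX equivalent; the name make_masked_residual and the argument order are hypothetical.

```python
import numpy as np

def make_masked_residual(func):
    """Wrap a fit function as a masked residual function (illustrative)."""
    def masked_residual(params, xdata, ydata, data_mask, transform):
        residual = func(xdata, *params) - ydata
        if transform is not None:
            # Assumed 1-D transform: elementwise weights such as 1 / sigma.
            residual = transform * residual
        # Masked-out (for example padded) points contribute zero residual.
        return np.where(data_mask, residual, 0.0)
    return masked_residual

def line(x, a, b):
    return a * x + b

residual = make_masked_residual(line)
x = np.array([0.0, 1.0, 2.0, 0.0])
y = np.array([1.0, 3.0, 5.5, 0.0])
mask = np.array([True, True, True, False])  # last point is padding

# The mask and the transform are always passed, even when the transform is None.
res = residual((2.0, 1.0), x, y, mask, None)
weighted = residual((2.0, 1.0), x, y, mask, np.array([1.0, 1.0, 2.0, 1.0]))
```

All state is passed in as arguments rather than read from the enclosing object, which is what keeps the wrapped function purely functional.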

wrap_jac(jac)[source]

Wraps a user-defined Jacobian function to allow for data masking and uncertainty transforms. The wrapped function is in a JAX JIT compatible format which is purely functional. This requires that both the data mask and the uncertainty transform are passed to the function.

The analytical Jacobian of the fit function is equivalent to the Jacobian of the residual function.

Also note that the analytical Jacobian doesn’t require the dependent data ydata, but we still need to pass it to the function to maintain compatibility with the autodiff version, which does require ydata.

Parameters:

jac (Callable) – The Jacobian function to wrap.

Returns:

The masked Jacobian of the function evaluated at args with respect to the data.

Return type:

jnp.ndarray

exception jaxfit.LinAlgError

Bases: ValueError

Generic Python-exception-derived object raised by linalg functions.

General purpose exception class, derived from Python’s ValueError class, programmatically raised in linalg functions when a Linear Algebra-related condition would prevent further correct execution of the function.

Parameters:

None

Examples

>>> import numpy as np
>>> from numpy import linalg as LA
>>> LA.inv(np.zeros((2,2)))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "...linalg.py", line 350,
    in inv return wrap(solve(a, identity(a.shape[0], dtype=a.dtype)))
  File "...linalg.py", line 249,
    in solve
    raise LinAlgError('Singular matrix')
numpy.linalg.LinAlgError: Singular matrix

class jaxfit.LinearOperator(*args, **kwargs)[source]

Bases: object

Common interface for performing matrix vector products

Many iterative methods (e.g. cg, gmres) do not need to know the individual entries of a matrix to solve a linear system A*x=b. Such solvers only require the computation of matrix vector products, A*v where v is a dense vector. This class serves as an abstract interface between iterative solvers and matrix-like objects.

To construct a concrete LinearOperator, either pass appropriate callables to the constructor of this class, or subclass it.

A subclass must implement either one of the methods _matvec and _matmat, and the attributes/properties shape (pair of integers) and dtype (may be None). It may call the __init__ on this class to have these attributes validated. Implementing _matvec automatically implements _matmat (using a naive algorithm) and vice-versa.

Optionally, a subclass may implement _rmatvec or _adjoint to implement the Hermitian adjoint (conjugate transpose). As with _matvec and _matmat, implementing either _rmatvec or _adjoint implements the other automatically. Implementing _adjoint is preferable; _rmatvec is mostly there for backwards compatibility.

Parameters:
  • shape (tuple) – Matrix dimensions (M, N).

  • matvec (callable f(v)) – Returns A * v.

  • rmatvec (callable f(v)) – Returns A^H * v, where A^H is the conjugate transpose of A.

  • matmat (callable f(V)) – Returns A * V, where V is a dense matrix with dimensions (N, K).

  • dtype (dtype) – Data type of the matrix.

  • rmatmat (callable f(V)) – Returns A^H * V, where V is a dense matrix with dimensions (M, K).

args

For linear operators describing products etc. of other linear operators, the operands of the binary operation.

Type:

tuple

ndim

Number of dimensions (this is always 2)

Type:

int

See also

aslinearoperator

Construct LinearOperators

Notes

The user-defined matvec() function must properly handle the case where v has shape (N,) as well as the (N,1) case. The shape of the return type is handled internally by LinearOperator.

LinearOperator instances can also be multiplied, added with each other and exponentiated, all lazily: the result of these operations is always a new, composite LinearOperator, that defers linear operations to the original operators and combines the results.

More details regarding how to subclass a LinearOperator and several examples of concrete LinearOperator instances can be found in the external project PyLops.

Examples

>>> import numpy as np
>>> from scipy.sparse.linalg import LinearOperator
>>> def mv(v):
...     return np.array([2*v[0], 3*v[1]])
...
>>> A = LinearOperator((2,2), matvec=mv)
>>> A
<2x2 _CustomLinearOperator with dtype=float64>
>>> A.matvec(np.ones(2))
array([ 2.,  3.])
>>> A * np.ones(2)
array([ 2.,  3.])

property H

Hermitian adjoint.

Returns the Hermitian adjoint of self, aka the Hermitian conjugate or Hermitian transpose. For a complex matrix, the Hermitian adjoint is equal to the conjugate transpose.

Can be abbreviated self.H instead of self.adjoint().

Returns:

A_H – Hermitian adjoint of self.

Return type:

LinearOperator

property T

Transpose this linear operator.

Returns a LinearOperator that represents the transpose of this one. Can be abbreviated self.T instead of self.transpose().

adjoint()[source]

Hermitian adjoint.

Returns the Hermitian adjoint of self, aka the Hermitian conjugate or Hermitian transpose. For a complex matrix, the Hermitian adjoint is equal to the conjugate transpose.

Can be abbreviated self.H instead of self.adjoint().

Returns:

A_H – Hermitian adjoint of self.

Return type:

LinearOperator

dot(x)[source]

Matrix-matrix or matrix-vector multiplication.

Parameters:

x (array_like) – 1-d or 2-d array, representing a vector or matrix.

Returns:

Ax – 1-d or 2-d array (depending on the shape of x) that represents the result of applying this linear operator on x.

Return type:

array

matmat(X)[source]

Matrix-matrix multiplication.

Performs the operation y=A*X where A is an MxN linear operator and X is a dense NxK matrix or ndarray.

Parameters:

X ({matrix, ndarray}) – An array with shape (N,K).

Returns:

Y – A matrix or ndarray with shape (M,K) depending on the type of the X argument.

Return type:

{matrix, ndarray}

Notes

This matmat wraps any user-specified matmat routine or overridden _matmat method to ensure that y has the correct type.

matvec(x)[source]

Matrix-vector multiplication.

Performs the operation y=A*x where A is an MxN linear operator and x is a column vector or 1-d array.

Parameters:

x ({matrix, ndarray}) – An array with shape (N,) or (N,1).

Returns:

y – A matrix or ndarray with shape (M,) or (M,1) depending on the type and shape of the x argument.

Return type:

{matrix, ndarray}

Notes

This matvec wraps the user-specified matvec routine or overridden _matvec method to ensure that y has the correct shape and type.

ndim = 2

rmatmat(X)[source]

Adjoint matrix-matrix multiplication.

Performs the operation y = A^H * x where A is an MxN linear operator and x is a column vector or 1-d array, or 2-d array. The default implementation defers to the adjoint.

Parameters:

X ({matrix, ndarray}) – A matrix or 2D array.

Returns:

Y – A matrix or 2D array depending on the type of the input.

Return type:

{matrix, ndarray}

Notes

This rmatmat wraps the user-specified rmatmat routine.

rmatvec(x)[source]

Adjoint matrix-vector multiplication.

Performs the operation y = A^H * x where A is an MxN linear operator and x is a column vector or 1-d array.

Parameters:

x ({matrix, ndarray}) – An array with shape (M,) or (M,1).

Returns:

y – A matrix or ndarray with shape (N,) or (N,1) depending on the type and shape of the x argument.

Return type:

{matrix, ndarray}

Notes

This rmatvec wraps the user-specified rmatvec routine or overridden _rmatvec method to ensure that y has the correct shape and type.

transpose()[source]

Transpose this linear operator.

Returns a LinearOperator that represents the transpose of this one. Can be abbreviated self.T instead of self.transpose().

class jaxfit.LossFunctionsJIT[source]

Bases: object

arctan(z, cost_only)[source]
cauchy(z, cost_only)[source]
construct_all_loss_functions()[source]
construct_single_loss_function(loss)[source]
create_arctan_funcs()[source]
create_calculate_cost()[source]
create_cauchy_funcs()[source]
create_huber_funcs()[source]
create_scale_rhos()[source]
create_soft_l1_funcs()[source]
create_stack_rhos()[source]
create_zscale()[source]
get_empty_rhos(z)[source]
get_loss_function(loss)[source]
huber(z, cost_only)[source]
soft_l1(z, cost_only)[source]

class jaxfit.OptimizeResult[source]

Bases: dict

Represents the optimization result.

x

The solution of the optimization.

Type:

ndarray

success

Whether or not the optimizer exited successfully.

Type:

bool

status

Termination status of the optimizer. Its value depends on the underlying solver. Refer to message for details.

Type:

int

message

Description of the cause of the termination.

Type:

str

fun, jac, hess

Values of objective function, its Jacobian and its Hessian (if available). The Hessians may be approximations, see the documentation of the function in question.

Type:

ndarray

hess_inv

Inverse of the objective function’s Hessian; may be an approximation. Not available for all solvers. The type of this attribute may be either np.ndarray or scipy.sparse.linalg.LinearOperator.

Type:

object

nfev, njev, nhev

Number of evaluations of the objective functions and of its Jacobian and Hessian.

Type:

int

nit

Number of iterations performed by the optimizer.

Type:

int

maxcv

The maximum constraint violation.

Type:

float

Notes

OptimizeResult may have additional attributes not listed here depending on the specific solver being used. Since this class is essentially a subclass of dict with attribute accessors, one can see which attributes are available using the OptimizeResult.keys method.
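A dict subclass with attribute accessors can be written in a few lines. The class below is a hypothetical stand-in used only to illustrate the idea in the note above; it is not the library's class.

```python
class ResultSketch(dict):
    """dict subclass with attribute access (hypothetical stand-in)."""

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails: fall back to keys.
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    # Attribute assignment and deletion act on the dict entries.
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

res = ResultSketch(x=[1.0, 2.0], success=True, status=1)
res.message = "converged"        # stored as a dict key
available = sorted(res.keys())   # lists every attribute that is present
```

Attribute access and key access are interchangeable here, so res.success and res["success"] refer to the same entry.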

exception jaxfit.OptimizeWarning[source]

Bases: UserWarning

class jaxfit.TrustRegionJITFunctions[source]

Bases: object

JIT functions for trust region algorithm.

create_calculate_cost()[source]

Create the function to calculate the cost function.

create_check_isfinite()[source]

Create the function to check if the evaluated residuals are finite.

create_default_loss_func()[source]

Create the default loss function which is simply the sum of the squares of the residuals.

create_grad_func()[source]

Create the function to compute the gradient of the loss function which is simply the function evaluation dotted with the Jacobian.

create_grad_hat()[source]

Calculate the gradient in the “hat” space, which is just multiplying the gradient by the diagonal matrix D. This is used in the trust region algorithm. Here we only use the diagonals of D, since D is diagonal.

create_svd_funcs()[source]

Create the functions to compute the SVD of the Jacobian matrix. There are two versions, one for problems with bounds and one for problems without bounds. The version for problems with bounds is slightly more complicated.
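The cost, gradient and "hat" gradient described above are short expressions. The NumPy sketch below illustrates them; the factor 0.5 in the cost follows SciPy's convention and is an assumption here, as are the function names.

```python
import numpy as np

def cost(f):
    # Assumed SciPy-style convention: half the sum of squared residuals.
    return 0.5 * np.dot(f, f)

def gradient(f, J):
    # Residuals dotted with the Jacobian, i.e. J.T @ f.
    return f @ J

def gradient_hat(g, d):
    # D is diagonal, so D @ g reduces to an elementwise product with d.
    return d * g

J = np.array([[1.0, 2.0],
              [3.0, 4.0]])
f = np.array([1.0, -1.0])
d = np.array([0.5, 2.0])

c = cost(f)
g = gradient(f, J)
g_h = gradient_hat(g, d)
```

Storing only the diagonal d avoids building the full n-by-n matrix D.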

class jaxfit.TrustRegionReflective[source]

Bases: TrustRegionJITFunctions

select_step(x, J_h, diag_h, g_h, p, p_h, d, Delta, lb, ub, theta)[source]

Select the best step according to Trust Region Reflective algorithm.

Parameters:
  • x (np.ndarray) – Current parameter vector.

  • J_h (jnp.ndarray) – Jacobian matrix in the scaled ‘hat’ space.

  • diag_h (jnp.ndarray) – Diagonal of the scaled matrix C = diag(g * scale) Jv?

  • g_h (jnp.ndarray) – Gradient vector in the scaled ‘hat’ space.

  • p (np.ndarray) – Trust-region step in the original space.

  • p_h (np.ndarray) – Trust-region step in the scaled ‘hat’ space.

  • d (np.ndarray) – Scaling vector.

  • Delta (float) – Trust-region radius.

  • lb (np.ndarray) – Lower bounds on variables.

  • ub (np.ndarray) – Upper bounds on variables.

  • theta (float) – Controls the step-back ratio from the bounds.

Returns:

  • step (np.ndarray) – Step in the original space.

  • step_h (np.ndarray) – Step in the scaled ‘hat’ space.

  • predicted_reduction (float) – Predicted reduction in the cost function.

trf(fun, xdata, ydata, jac, data_mask, transform, x0, f0, J0, lb, ub, ftol, xtol, gtol, max_nfev, f_scale, x_scale, loss_function, tr_options, verbose, timeit=False)[source]

Minimize a scalar function of one or more variables using the trust-region reflective algorithm. Although I think this is not good coding style, I maintained the original code format from SciPy so that the code is easier to compare with the original. See the note from the algorithm’s original author below.

For efficiency, it makes sense to run the simplified version of the algorithm when no bounds are imposed. We decided to write the two separate functions. It violates the DRY principle, but the individual functions are kept the most readable.

Parameters:
  • fun (callable) – The residual function

  • xdata (array_like or tuple of array_like) – The independent variable where the data is measured. If xdata is a tuple, then the input arguments to fun are assumed to be (xdata[0], xdata[1], ...).

  • ydata (jnp.ndarray) – The dependent data

  • jac (callable) – The Jacobian of fun.

  • data_mask (jnp.ndarray) – The mask for the data.

  • transform (jnp.ndarray) – The uncertainty transform for the data.

  • x0 (jnp.ndarray) – Initial guess. Array of real elements of size (n,), where ‘n’ is the number of independent variables.

  • f0 (jnp.ndarray) – Initial residuals. Array of real elements of size (m,), where ‘m’ is the number of data points.

  • J0 (jnp.ndarray) – Initial Jacobian. Array of real elements of size (m, n), where ‘m’ is the number of data points and ‘n’ is the number of independent variables.

  • lb (jnp.ndarray) – Lower bounds on independent variables. Array of real elements of size (n,), where ‘n’ is the number of independent variables.

  • ub (jnp.ndarray) – Upper bounds on independent variables. Array of real elements of size (n,), where ‘n’ is the number of independent variables.

  • ftol (float) – Tolerance for termination by the change of the cost function.

  • xtol (float) – Tolerance for termination by the change of the independent variables.

  • gtol (float) – Tolerance for termination by the norm of the gradient.

  • max_nfev (int) – Maximum number of function evaluations.

  • f_scale (float) – Cost function scalar

  • x_scale (jnp.ndarray) – Scaling factors for independent variables.

  • loss_function (callable, optional) – Loss function. If None, the standard least-squares problem is solved.

  • tr_options (dict) – Options for the trust-region algorithm.

  • verbose (int) –

    Level of algorithm’s verbosity:

    • 0 (default) : work silently.

    • 1 : display a termination report.

  • timeit (bool, optional) – If True, the time for each step is measured if the unbounded version is being run. Default is False.

Return type:

Dict

trf_bounds(fun, xdata, ydata, jac, data_mask, transform, x0, f, J, lb, ub, ftol, xtol, gtol, max_nfev, f_scale, x_scale, loss_function, tr_options, verbose)[source]

Bounded version of the trust-region reflective algorithm.

Parameters:
  • fun (callable) – The residual function

  • xdata (array_like or tuple of array_like) – The independent variable where the data is measured. If xdata is a tuple, then the input arguments to fun are assumed to be (xdata[0], xdata[1], ...).

  • ydata (jnp.ndarray) – The dependent data

  • jac (callable) – The Jacobian of fun.

  • data_mask (jnp.ndarray) – The mask for the data.

  • transform (jnp.ndarray) – The uncertainty transform for the data.

  • x0 (jnp.ndarray) – Initial guess. Array of real elements of size (n,), where ‘n’ is the number of independent variables.

  • f0 (jnp.ndarray) – Initial residuals. Array of real elements of size (m,), where ‘m’ is the number of data points.

  • J0 (jnp.ndarray) – Initial Jacobian. Array of real elements of size (m, n), where ‘m’ is the number of data points and ‘n’ is the number of independent variables.

  • lb (jnp.ndarray) – Lower bounds on independent variables. Array of real elements of size (n,), where ‘n’ is the number of independent variables.

  • ub (jnp.ndarray) – Upper bounds on independent variables. Array of real elements of size (n,), where ‘n’ is the number of independent variables.

  • ftol (float) – Tolerance for termination by the change of the cost function.

  • xtol (float) – Tolerance for termination by the change of the independent variables.

  • gtol (float) – Tolerance for termination by the norm of the gradient.

  • max_nfev (int) – Maximum number of function evaluations.

  • f_scale (float) – Cost function scalar

  • x_scale (jnp.ndarray) – Scaling factors for independent variables.

  • loss_function (callable, optional) – Loss function. If None, the standard least-squares problem is solved.

  • tr_options (dict) – Options for the trust-region algorithm.

  • verbose (int) –

    Level of algorithm’s verbosity:

    • 0 (default) : work silently.

    • 1 : display a termination report.

  • f (Array) –

  • J (Array) –

Returns:

result – The optimization result represented as an OptimizeResult object. Important attributes are: x, the solution array; success, a Boolean flag indicating whether the optimizer exited successfully; and message, which describes the cause of the termination. See OptimizeResult for a description of other attributes.

Return type:

OptimizeResult

Notes

The algorithm is described in [13].

References

trf_no_bounds(fun, xdata, ydata, jac, data_mask, transform, x0, f, J, lb, ub, ftol, xtol, gtol, max_nfev, f_scale, x_scale, loss_function, tr_options, verbose)[source]

Unbounded version of the trust-region reflective algorithm.

Parameters:
  • fun (callable) – The residual function

  • xdata (array_like or tuple of array_like) – The independent variable where the data is measured. If xdata is a tuple, then the input arguments to fun are assumed to be (xdata[0], xdata[1], ...).

  • ydata (jnp.ndarray) – The dependent data

  • jac (callable) – The Jacobian of fun.

  • data_mask (jnp.ndarray) – The mask for the data.

  • transform (jnp.ndarray) – The uncertainty transform for the data.

  • x0 (jnp.ndarray) – Initial guess. Array of real elements of size (n,), where ‘n’ is the number of independent variables.

  • f0 (jnp.ndarray) – Initial residuals. Array of real elements of size (m,), where ‘m’ is the number of data points.

  • J0 (jnp.ndarray) – Initial Jacobian. Array of real elements of size (m, n), where ‘m’ is the number of data points and ‘n’ is the number of independent variables.

  • lb (jnp.ndarray) – Lower bounds on independent variables. Array of real elements of size (n,), where ‘n’ is the number of independent variables.

  • ub (jnp.ndarray) – Upper bounds on independent variables. Array of real elements of size (n,), where ‘n’ is the number of independent variables.

  • ftol (float) – Tolerance for termination by the change of the cost function.

  • xtol (float) – Tolerance for termination by the change of the independent variables.

  • gtol (float) – Tolerance for termination by the norm of the gradient.

  • max_nfev (int) – Maximum number of function evaluations.

  • f_scale (float) – Cost function scalar

  • x_scale (jnp.ndarray) – Scaling factors for independent variables.

  • loss_function (callable, optional) – Loss function. If None, the standard least-squares problem is solved.

  • tr_options (dict) – Options for the trust-region algorithm.

  • verbose (int) –

    Level of algorithm’s verbosity:

    • 0 (default) : work silently.

    • 1 : display a termination report.

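  • f (Array) – Initial residuals, of size (m,). The parameter list above documents this argument under the name f0.

  • J (Array) – Initial Jacobian, of size (m, n). The parameter list above documents this argument under the name J0.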

Returns:

result – The optimization result, represented as an OptimizeResult object. Important attributes are: x, the solution array; success, a Boolean flag indicating whether the optimizer exited successfully; and message, which describes the cause of the termination. See OptimizeResult for a description of other attributes.

Return type:

OptimizeResult

Notes

The algorithm is described in [13].

trf_no_bounds_timed(fun, xdata, ydata, jac, data_mask, transform, x0, f, J, lb, ub, ftol, xtol, gtol, max_nfev, f_scale, x_scale, loss_function, tr_options, verbose)[source]

Trust Region Reflective algorithm with no bounds, in which every operation performed by JAX on the GPU is timed. A separate function is needed because timing each operation requires a call to block_until_ready(), which makes the main Python thread wait until the GPU has finished the operation. The main algorithm should not wait for the GPU after each operation, because doing so would slow it down. This function is therefore intended only for analysing the performance of the algorithm.

Parameters:
  • fun (callable) – The residual function

  • xdata (array_like or tuple of array_like) – The independent variable where the data is measured. If xdata is a tuple, then the input arguments to fun are assumed to be (xdata[0], xdata[1], ...).

  • ydata (jnp.ndarray) – The dependent data

  • jac (callable) – The Jacobian of fun.

  • data_mask (jnp.ndarray) – The mask for the data.

  • transform (jnp.ndarray) – The uncertainty transform for the data.

  • x0 (jnp.ndarray) – Initial guess. Array of real elements of size (n,), where ‘n’ is the number of independent variables.

  • f0 (jnp.ndarray) – Initial residuals. Array of real elements of size (m,), where ‘m’ is the number of data points.

  • J0 (jnp.ndarray) – Initial Jacobian. Array of real elements of size (m, n), where ‘m’ is the number of data points and ‘n’ is the number of independent variables.

  • lb (jnp.ndarray) – Lower bounds on independent variables. Array of real elements of size (n,), where ‘n’ is the number of independent variables.

  • ub (jnp.ndarray) – Upper bounds on independent variables. Array of real elements of size (n,), where ‘n’ is the number of independent variables.

  • ftol (float) – Tolerance for termination by the change of the cost function.

  • xtol (float) – Tolerance for termination by the change of the independent variables.

  • gtol (float) – Tolerance for termination by the norm of the gradient.

  • max_nfev (int) – Maximum number of function evaluations.

  • f_scale (float) – Cost function scalar

  • x_scale (jnp.ndarray) – Scaling factors for independent variables.

  • loss_function (callable, optional) – Loss function. If None, the standard least-squares problem is solved.

  • tr_options (dict) – Options for the trust-region algorithm.

  • verbose (int) –

    Level of algorithm’s verbosity:

    • 0 (default) : work silently.

    • 1 : display a termination report.

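  • f (Array) – Initial residuals, of size (m,). The parameter list above documents this argument under the name f0.

  • J (Array) – Initial Jacobian, of size (m, n). The parameter list above documents this argument under the name J0.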

Returns:

result – The optimization result, represented as an OptimizeResult object. Important attributes are: x, the solution array; success, a Boolean flag indicating whether the optimizer exited successfully; and message, which describes the cause of the termination. See OptimizeResult for a description of other attributes.

Return type:

OptimizeResult

Notes

The algorithm is described in [13].
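
The timing approach described for trf_no_bounds_timed can be illustrated with a minimal sketch. It is not taken from the jaxfit source: the function normal_matrix and the array sizes are hypothetical, and only the use of block_until_ready() around a timed JAX call reflects the text above.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def normal_matrix(J):
    # Hypothetical stand-in for one operation inside the algorithm: J.T @ J.
    return J.T @ J


J = jnp.ones((200, 3))

# Warm-up call so that compilation time is not included in the measurement.
normal_matrix(J).block_until_ready()

start = time.perf_counter()
# block_until_ready() makes Python wait until the device has finished,
# so the elapsed time covers the computation and not just its dispatch.
out = normal_matrix(J).block_until_ready()
elapsed = time.perf_counter() - start

print(f"J.T @ J took {elapsed:.6f} s")
```

Without the block_until_ready() call, the timer would typically stop as soon as the operation had been dispatched, which is why the untimed variants avoid it.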

jaxfit.aslinearoperator(A)[source]

Return A as a LinearOperator.

‘A’ may be any of the following types:
  • ndarray

  • matrix

  • sparse matrix (e.g. csr_matrix, lil_matrix, etc.)

  • LinearOperator

  • An object with .shape and .matvec attributes

See the LinearOperator documentation for additional information.

Notes

If ‘A’ has no .dtype attribute, the data type is determined by calling LinearOperator.matvec() - set the .dtype attribute to prevent this call upon the linear operator creation.

Examples

>>> import numpy as np
>>> from scipy.sparse.linalg import aslinearoperator
>>> M = np.array([[1,2,3],[4,5,6]], dtype=np.int32)
>>> aslinearoperator(M)
<2x3 MatrixLinearOperator with dtype=int32>

jaxfit.check_termination(dF, F, dx_norm, x_norm, ratio, ftol, xtol)[source]

Check termination condition for nonlinear least squares.

jaxfit.check_tolerance(ftol, xtol, gtol, method)[source]

Check and prepare tolerance values for optimization.

This function checks the tolerance values for the optimization and prepares them for use. If any of the tolerances is None, it is set to 0. If any of the tolerances is lower than the machine epsilon, a warning is issued and the tolerance is set to the machine epsilon. If all tolerances are lower than the machine epsilon, a ValueError is raised.

Parameters:
  • ftol (float) – The tolerance for the optimization function value.

  • xtol (float) – The tolerance for the optimization variable values.

  • gtol (float) – The tolerance for the optimization gradient values.

  • method (str) – The name of the optimization method.

Returns:

The prepared tolerance values.

Return type:

Tuple[float, float, float]

jaxfit.check_x_scale(x_scale, x0)[source]

Check and prepare the x_scale parameter for optimization.

This function checks and prepares the x_scale parameter for the optimization. x_scale can be either ‘jac’ or an array_like of positive numbers. If it is ‘jac’, the Jacobian is used for the scaling.

Parameters:
  • x_scale (Union[str, Sequence[float]]) – The scaling for the optimization variables.

  • x0 (Sequence[float]) – The initial guess for the optimization variables.

Returns:

The prepared x_scale parameter.

Return type:

Union[str, Sequence[float]]
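
The check described for check_x_scale might look roughly like the following. This is a hypothetical NumPy re-implementation for illustration only: the name check_x_scale_sketch, the error message, the finiteness check, and the broadcasting of a scalar to the shape of x0 are assumptions, not confirmed jaxfit behaviour.

```python
import numpy as np


def check_x_scale_sketch(x_scale, x0):
    """Hypothetical sketch: accept 'jac' or an array_like of positive numbers."""
    if isinstance(x_scale, str) and x_scale == "jac":
        return x_scale  # scaling will be derived from the Jacobian later

    try:
        scale = np.asarray(x_scale, dtype=float)
        valid = bool(np.all(np.isfinite(scale)) and np.all(scale > 0))
    except (ValueError, TypeError):
        valid = False

    if not valid:
        raise ValueError("x_scale must be 'jac' or array_like with positive numbers.")

    if scale.ndim == 0:
        # Assumption: a scalar is repeated once per optimization variable.
        scale = np.resize(scale, np.shape(x0))
    return scale


print(check_x_scale_sketch("jac", [1.0, 2.0]))
print(check_x_scale_sketch(2.0, [1.0, 2.0, 3.0]))
```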

jaxfit.cho_factor(a, lower=False, overwrite_a=False, check_finite=True)[source]

Compute the Cholesky decomposition of a matrix, to use in cho_solve

Returns a matrix containing the Cholesky decomposition, A = L L* or A = U* U of a Hermitian positive-definite matrix a. The return value can be directly used as the first parameter to cho_solve.

Warning

The returned matrix also contains random data in the entries not used by the Cholesky decomposition. If you need to zero these entries, use the function cholesky instead.

Parameters:
  • a ((M, M) array_like) – Matrix to be decomposed

  • lower (bool, optional) – Whether to compute the upper or lower triangular Cholesky factorization (Default: upper-triangular)

  • overwrite_a (bool, optional) – Whether to overwrite data in a (may improve performance)

  • check_finite (bool, optional) – Whether to check that the input matrix contains only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs.

Returns:

  • c ((M, M) ndarray) – Matrix whose upper or lower triangle contains the Cholesky factor of a. Other parts of the matrix contain random data.

  • lower (bool) – Flag indicating whether the factor is in the lower or upper triangle

Raises:

LinAlgError – Raised if decomposition fails.

See also

cho_solve

Solve a linear set of equations using the Cholesky factorization of a matrix.

Examples

>>> import numpy as np
>>> from scipy.linalg import cho_factor
>>> A = np.array([[9, 3, 1, 5], [3, 7, 5, 1], [1, 5, 9, 2], [5, 1, 2, 6]])
>>> c, low = cho_factor(A)
>>> c
array([[3.        , 1.        , 0.33333333, 1.66666667],
       [3.        , 2.44948974, 1.90515869, -0.27216553],
       [1.        , 5.        , 2.29330749, 0.8559528 ],
       [5.        , 1.        , 2.        , 1.55418563]])
>>> np.allclose(np.triu(c).T @ np.triu(c) - A, np.zeros((4, 4)))
True

jaxfit.cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True)[source]

Solve the linear equations A x = b, given the Cholesky factorization of A.

Parameters:
  • (c, lower) (tuple, (array, bool)) – Cholesky factorization of a, as given by cho_factor

  • b (array) – Right-hand side

  • overwrite_b (bool, optional) – Whether to overwrite data in b (may improve performance)

  • check_finite (bool, optional) – Whether to check that the input matrices contain only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs.

Returns:

x – The solution to the system A x = b

Return type:

array

See also

cho_factor

Cholesky factorization of a matrix

Examples

>>> import numpy as np
>>> from scipy.linalg import cho_factor, cho_solve
>>> A = np.array([[9, 3, 1, 5], [3, 7, 5, 1], [1, 5, 9, 2], [5, 1, 2, 6]])
>>> c, low = cho_factor(A)
>>> x = cho_solve((c, low), [1, 1, 1, 1])
>>> np.allclose(A @ x - [1, 1, 1, 1], np.zeros(4))
True

jaxfit.copysign(x, y, /)

Return a float with the magnitude (absolute value) of x but the sign of y.

On platforms that support signed zeros, copysign(1.0, -0.0) returns -1.0.
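
A short illustration of this behaviour, assuming that the function documented here is the standard library's math.copysign (the signature and description match, but the source does not state this explicitly):

```python
import math

flipped = math.copysign(3.0, -2.0)       # magnitude of 3.0, sign of -2.0
signed_zero = math.copysign(1.0, -0.0)   # the sign of negative zero is respected
unchanged = math.copysign(-4.5, 1.0)     # a negative magnitude becomes positive

print(flipped, signed_zero, unchanged)
```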

jaxfit.curve_fit(f, *args, **kwargs)[source]

jaxfit.evaluate_quadratic(J, g, s, diag=None)[source]

Compute values of a quadratic function arising in least squares. The function is 0.5 * s.T * (J.T * J + diag) * s + g.T * s.

Parameters:
  • J (ndarray, sparse matrix or LinearOperator, shape (m, n)) – Jacobian matrix, affects the quadratic term.

  • g (ndarray, shape (n,)) – Gradient, defines the linear term.

  • s (ndarray, shape (k, n) or (n,)) – Array containing steps as rows.

  • diag (ndarray, shape (n,), optional) – Additional diagonal part, affects the quadratic term. If None, assumed to be 0.

Returns:

values – Values of the function. If s was 2-D, then ndarray is returned, otherwise, float is returned.

Return type:

ndarray with shape (k,) or float
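
The formula given for evaluate_quadratic can be sketched in NumPy as follows. The function name evaluate_quadratic_sketch and the test data are hypothetical; the sketch handles dense arrays only and is not the jaxfit implementation.

```python
import numpy as np


def evaluate_quadratic_sketch(J, g, s, diag=None):
    """Hypothetical sketch of 0.5 * s.T (J.T J + diag) s + g.T s."""
    J, g, s = (np.asarray(a, dtype=float) for a in (J, g, s))
    if s.ndim == 1:
        Js = J @ s                      # shape (m,)
        q = Js @ Js                     # s.T J.T J s without forming J.T J
        if diag is not None:
            q += s @ (np.asarray(diag, dtype=float) * s)
    else:
        Js = J @ s.T                    # shape (m, k), one column per step
        q = np.sum(Js**2, axis=0)
        if diag is not None:
            q += np.sum(np.asarray(diag, dtype=float) * s**2, axis=1)
    return 0.5 * q + s @ g              # add the linear term g.T s


J = np.array([[1.0, 0.0], [0.0, 2.0]])
g = np.array([1.0, -1.0])
single = evaluate_quadratic_sketch(J, g, [1.0, 1.0])
batch = evaluate_quadratic_sketch(J, g, [[1.0, 1.0], [2.0, 0.0]], diag=[1.0, 1.0])
print(single, batch)
```

Computing J @ s first avoids building the (n, n) matrix J.T @ J, which is the usual reason for writing the quadratic in this form.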

jaxfit.find_active_constraints(x, lb, ub, rtol=1e-10)[source]

Determine which constraints are active at a given point. The threshold is computed using rtol and the absolute value of the closest bound.

Returns:

active

Each component shows whether the corresponding constraint is active:
  • 0 - a constraint is not active.

  • -1 - a lower bound is active.

  • 1 - an upper bound is active.

Return type:

ndarray of int with shape of x
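
One way to realise the rule above is sketched below. The threshold rtol * max(1, |bound|) is an assumption modelled on common practice, the name find_active_constraints_sketch is hypothetical, and the sketch assumes rtol > 0 and that x, lb and ub have the same shape.

```python
import numpy as np


def find_active_constraints_sketch(x, lb, ub, rtol=1e-10):
    """Hypothetical sketch: -1 lower bound active, 1 upper bound active, 0 neither."""
    x, lb, ub = (np.asarray(a, dtype=float) for a in (x, lb, ub))
    active = np.zeros_like(x, dtype=int)

    lower_dist = x - lb
    upper_dist = ub - x
    # Assumed threshold: relative to the bound's magnitude, but at least rtol.
    lower_threshold = rtol * np.maximum(1, np.abs(lb))
    upper_threshold = rtol * np.maximum(1, np.abs(ub))

    # A bound is active if it is finite, closest to x, and within the threshold.
    lower_active = np.isfinite(lb) & (lower_dist <= np.minimum(upper_dist, lower_threshold))
    upper_active = np.isfinite(ub) & (upper_dist <= np.minimum(lower_dist, upper_threshold))

    active[lower_active] = -1
    active[upper_active] = 1
    return active


active = find_active_constraints_sketch(
    x=[1.0, 5.0, 10.0], lb=[1.0, 0.0, -np.inf], ub=[np.inf, 10.0, 10.0]
)
print(active)
```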

jaxfit.in_bounds(x, lb, ub)[source]

Check if a point lies within bounds.

jaxfit.intersect_trust_region(x, s, Delta)[source]

Find the intersection of a line with the boundary of a trust region. This function solves the quadratic equation ||x + s*t||**2 = Delta**2 with respect to t.

Returns:

t_neg, t_pos – Negative and positive roots.

Return type:

tuple of float

Raises:

ValueError – If s is zero or x is not within the trust region.
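
Expanding ||x + s*t||**2 = Delta**2 gives a*t**2 + 2*b*t + c = 0 with a = s.s, b = x.s and c = x.x - Delta**2. A plain-Python sketch of solving it is shown below; the function name and the cancellation-avoiding root formula are illustrative choices, not a copy of the jaxfit code.

```python
import math


def intersect_trust_region_sketch(x, s, Delta):
    """Hypothetical sketch: roots t of ||x + s*t||**2 = Delta**2."""
    a = sum(si * si for si in s)
    if a == 0:
        raise ValueError("s is zero.")
    b = sum(xi * si for xi, si in zip(x, s))
    c = sum(xi * xi for xi in x) - Delta**2
    if c > 0:
        raise ValueError("x is not within the trust region.")

    d = math.sqrt(b * b - a * c)        # non-negative because c <= 0
    q = -(b + math.copysign(d, b))      # avoids subtracting nearly equal numbers
    if q == 0:
        return 0.0, 0.0                 # degenerate case: x on the boundary, s tangent
    t1, t2 = q / a, c / q               # the product of the roots is c / a
    return (t1, t2) if t1 < t2 else (t2, t1)


t_neg, t_pos = intersect_trust_region_sketch([1.0, 0.0], [1.0, 0.0], 2.0)
print(t_neg, t_pos)
```

In this example the point x = (1, 0) moves along the first axis, so it reaches the circle of radius 2 at t = 1 in the forward direction and at t = -3 in the backward direction.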

jaxfit.issparse(x)

Is x of a sparse matrix type?

Parameters:

x – object to check for being a sparse matrix

Returns:

True if x is a sparse matrix, False otherwise

Return type:

bool

Notes

issparse and isspmatrix are aliases for the same function.

Examples

>>> from scipy.sparse import csr_matrix, isspmatrix
>>> isspmatrix(csr_matrix([[5]]))
True
>>> from scipy.sparse import isspmatrix
>>> isspmatrix(5)
False

jaxfit.jacfwd(fun, argnums=0, has_aux=False, holomorphic=False)[source]

Jacobian of fun evaluated column-by-column using forward-mode AD.

Parameters:
  • fun (Callable) – Function whose Jacobian is to be computed.

  • argnums (int | Sequence[int]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

  • has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

Returns:

A function with the same arguments as fun, that evaluates the Jacobian of fun using forward-mode automatic differentiation. If has_aux is True then a pair of (jacobian, auxiliary_data) is returned.

Return type:

Callable

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...   return jnp.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacfwd(f)(jnp.array([1., 2., 3.])))
[[ 1.       0.       0.     ]
 [ 0.       0.       5.     ]
 [ 0.      16.      -2.     ]
 [ 1.6209   0.       0.84147]]

jaxfit.jax_solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False, overwrite_b=False, debug=None, check_finite=True)

Solve the equation a x = b for x, assuming a is a triangular matrix.

LAX-backend implementation of scipy.linalg._basic.solve_triangular().

Does not support the Scipy argument check_finite=True, because compiled JAX code cannot perform checks of array values at runtime.

Does not support the Scipy argument overwrite_*=True.

Original docstring below.

Parameters:
  • a ((M, M) array_like) – A triangular matrix

  • b ((M,) or (M, N) array_like) – Right-hand side matrix in a x = b

  • lower (bool, optional) – Use only data contained in the lower triangle of a. Default is to use upper triangle.

  • trans ({0, 1, 2, 'N', 'T', 'C'}, optional) –

    Type of system to solve:

    trans       system
    0 or ‘N’    a x = b
    1 or ‘T’    a^T x = b
    2 or ‘C’    a^H x = b

  • unit_diagonal (bool, optional) – If True, diagonal elements of a are assumed to be 1 and will not be referenced.

  • overwrite_b (bool) –

  • debug (Any | None) –

  • check_finite (bool) –

Returns:

x – Solution to the system a x = b. Shape of return matches b.

Return type:

(M,) or (M, N) ndarray

jaxfit.jax_svd(a, full_matrices=True, compute_uv=True, overwrite_a=False, check_finite=True, lapack_driver='gesdd')

Singular Value Decomposition.

LAX-backend implementation of scipy.linalg._decomp_svd.svd().

Does not support the Scipy argument check_finite=True, because compiled JAX code cannot perform checks of array values at runtime.

Does not support the Scipy argument overwrite_*=True.

Original docstring below.

Factorizes the matrix a into two unitary matrices U and Vh, and a 1-D array s of singular values (real, non-negative) such that a == U @ S @ Vh, where S is a suitably shaped matrix of zeros with main diagonal s.

Parameters:
  • a ((M, N) array_like) – Matrix to decompose.

  • full_matrices (bool, optional) – If True (default), U and Vh are of shape (M, M), (N, N). If False, the shapes are (M, K) and (K, N), where K = min(M, N).

  • compute_uv (bool, optional) – Whether to compute also U and Vh in addition to s. Default is True.

  • overwrite_a (bool) –

  • check_finite (bool) –

  • lapack_driver (str) –

Returns:

  • U (ndarray) – Unitary matrix having left singular vectors as columns. Of shape (M, M) or (M, K), depending on full_matrices.

  • s (ndarray) – The singular values, sorted in non-increasing order. Of shape (K,), with K = min(M, N).

  • Vh (ndarray) – Unitary matrix having right singular vectors as rows. Of shape (N, N) or (K, N) depending on full_matrices.

  • For compute_uv=False, only s is returned.

Return type:

Array | Tuple[Array, Array, Array]

jaxfit.jit(fun, in_shardings=UnspecifiedValue, out_shardings=UnspecifiedValue, static_argnums=None, static_argnames=None, donate_argnums=(), keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)[source]

Sets up fun for just-in-time compilation with XLA.

Parameters:
  • fun (Callable) –

    Function to be jitted. fun should be a pure function, as side-effects may only be executed once.

    The arguments and return value of fun should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_argnums can be anything at all, provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined.

    JAX keeps a weak reference to fun for use as a compilation cache key, so the object fun must be weakly-referenceable. Most Callable objects will already satisfy this requirement.

  • in_shardings

    Pytree of structure matching that of arguments to fun, with all actual arguments replaced by resource assignment specifications. It is also valid to specify a pytree prefix (e.g. one value in place of a whole subtree), in which case the leaves get broadcast to all values in that subtree.

    The in_shardings argument is optional. JAX will infer the shardings from the input jax.Array’s and defaults to replicating the input if the sharding cannot be inferred.

    The valid resource assignment specifications are:
    • XLACompatibleSharding, which will decide how the value will be partitioned. With this, using a mesh context manager is not required.

    • None, which gives JAX the freedom to choose whatever sharding it wants. For in_shardings, JAX will mark it as replicated, but this behavior can change in the future. For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.

    The size of every dimension has to be a multiple of the total number of resources assigned to it. This is similar to pjit’s in_shardings.

  • out_shardings

    Like in_shardings, but specifies resource assignment for function outputs. This is similar to pjit’s out_shardings.

    The out_shardings argument is optional. If not specified, jax.jit() will use GSPMD’s sharding propagation to figure out what the sharding of the output(s) should be.

  • static_argnums (int | Sequence[int] | None) –

    An optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in Python (during tracing), and so the corresponding argument values can be any Python object.

    Static arguments should be hashable, meaning both __hash__ and __eq__ are implemented, and immutable. Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not arrays or containers thereof must be marked as static.

    If neither static_argnums nor static_argnames is provided, no arguments are treated as static. If static_argnums is not provided but static_argnames is, or vice versa, JAX uses inspect.signature(fun) to find any positional arguments that correspond to static_argnames (or vice versa). If both static_argnums and static_argnames are provided, inspect.signature is not used, and only actual parameters listed in either static_argnums or static_argnames will be treated as static.

  • static_argnames (str | Iterable[str] | None) – An optional string or collection of strings specifying which named arguments to treat as static (compile-time constant). See the comment on static_argnums for details. If not provided but static_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.

  • donate_argnums (int | Sequence[int]) –

    Specify which positional argument buffers are “donated” to the computation. It is safe to donate argument buffers if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to. By default, no argument buffers are donated. Note that donate_argnums only works for positional arguments; keyword arguments will not be donated.

    For more details on buffer donation see the FAQ.

  • keep_unused (bool) – If False (the default), arguments that JAX determines to be unused by fun may be dropped from resulting compiled XLA executables. Such arguments will not be transferred to the device nor provided to the underlying executable. If True, unused arguments will not be pruned.

  • device (Device | None) – This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA’s DeviceAssignment logic and is usually to use jax.devices()[0].

  • backend (str | None) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.

  • inline (bool) – Specify whether this function should be inlined into enclosing jaxprs (rather than being represented as an application of the xla_call primitive with its own subjaxpr). Default False.

  • abstracted_axes (Any | None) –

Returns:

A wrapped version of fun, set up for just-in-time compilation.

Return type:

Wrapped

Examples

In the following example, selu can be compiled into a single fused kernel by XLA:

>>> import jax
>>>
>>> @jax.jit
... def selu(x, alpha=1.67, lmbda=1.05):
...   return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
>>>
>>> key = jax.random.PRNGKey(0)
>>> x = jax.random.normal(key, (10,))
>>> print(selu(x))  
[-0.54485  0.27744 -0.29255 -0.91421 -0.62452 -0.24748
-0.85743 -0.78232  0.76827  0.59566 ]

To pass arguments such as static_argnames when decorating a function, a common pattern is to use functools.partial():

>>> from functools import partial
>>> import jax.numpy as jnp
>>>
>>> @partial(jax.jit, static_argnames=['n'])
... def g(x, n):
...   for i in range(n):
...     x = x ** 2
...   return x
>>>
>>> g(jnp.arange(4), 3)
Array([   0,    1,  256, 6561], dtype=int32)

jaxfit.make_strictly_feasible(x, lb, ub, rstep=1e-10)[source]

Shift a point to the interior of a feasible region. Each element of the returned vector is at least at a relative distance rstep from the closest bound. If rstep=0 then np.nextafter is used.

jaxfit.minimize_quadratic_1d(a, b, lb, ub, c=0)[source]

Minimize a 1-D quadratic function subject to bounds. The free term c is 0 by default. Bounds must be finite.

Returns:

  • t (float) – Minimum point.

  • y (float) – Minimum value.
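
A minimal sketch of bounded 1-D quadratic minimization follows. It assumes the function has the form a*t**2 + b*t + c, which the text above does not spell out, and the name minimize_quadratic_1d_sketch is hypothetical.

```python
def minimize_quadratic_1d_sketch(a, b, lb, ub, c=0.0):
    """Hypothetical sketch: minimize a*t**2 + b*t + c over the interval [lb, ub]."""
    # The minimum lies at a bound or at the stationary point, if that is interior.
    candidates = [lb, ub]
    if a != 0:
        extremum = -0.5 * b / a
        if lb < extremum < ub:
            candidates.append(extremum)

    values = [t * (a * t + b) + c for t in candidates]
    best = min(range(len(values)), key=values.__getitem__)
    return candidates[best], values[best]


interior = minimize_quadratic_1d_sketch(1.0, -2.0, -1.0, 3.0)
at_bound = minimize_quadratic_1d_sketch(1.0, -2.0, 2.0, 3.0)
concave = minimize_quadratic_1d_sketch(-1.0, 0.0, -1.0, 2.0)
print(interior, at_bound, concave)
```

Comparing function values at all candidates, rather than trusting the stationary point, keeps the result correct when a is negative and the stationary point is a maximum.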

jaxfit.norm(x, ord=None, axis=None, keepdims=False)

Matrix or vector norm.

This function is able to return one of eight different matrix norms, or one of an infinite number of vector norms (described below), depending on the value of the ord parameter.

Parameters:
  • x (array_like) – Input array. If axis is None, x must be 1-D or 2-D, unless ord is None. If both axis and ord are None, the 2-norm of x.ravel will be returned.

  • ord ({non-zero int, inf, -inf, 'fro', 'nuc'}, optional) – Order of the norm (see table under Notes). inf means numpy’s inf object. The default is None.

  • axis ({None, int, 2-tuple of ints}, optional) –

    If axis is an integer, it specifies the axis of x along which to compute the vector norms. If axis is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix norms of these matrices are computed. If axis is None then either a vector norm (when x is 1-D) or a matrix norm (when x is 2-D) is returned. The default is None.

    New in version 1.8.0.

  • keepdims (bool, optional) –

    If this is set to True, the axes which are normed over are left in the result as dimensions with size one. With this option the result will broadcast correctly against the original x.

    New in version 1.10.0.

Returns:

n – Norm of the matrix or vector(s).

Return type:

float or ndarray

See also

scipy.linalg.norm

Similar function in SciPy.

Notes

For values of ord < 1, the result is, strictly speaking, not a mathematical ‘norm’, but it may still be useful for various numerical purposes.

The following norms can be calculated:

ord      norm for matrices              norm for vectors
None     Frobenius norm                 2-norm
‘fro’    Frobenius norm                 –
‘nuc’    nuclear norm                   –
inf      max(sum(abs(x), axis=1))       max(abs(x))
-inf     min(sum(abs(x), axis=1))       min(abs(x))
0        –                              sum(x != 0)
1        max(sum(abs(x), axis=0))       as below
-1       min(sum(abs(x), axis=0))       as below
2        2-norm (largest sing. value)   as below
-2       smallest singular value        as below
other    –                              sum(abs(x)**ord)**(1./ord)

The Frobenius norm is given by [1]:

\(||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}\)

The nuclear norm is the sum of the singular values.

Both the Frobenius and nuclear norm orders are only defined for matrices and raise a ValueError when x.ndim != 2.

Examples

>>> import numpy as np
>>> from numpy import linalg as LA
>>> a = np.arange(9) - 4
>>> a
array([-4, -3, -2, ...,  2,  3,  4])
>>> b = a.reshape((3, 3))
>>> b
array([[-4, -3, -2],
       [-1,  0,  1],
       [ 2,  3,  4]])
>>> LA.norm(a)
7.745966692414834
>>> LA.norm(b)
7.745966692414834
>>> LA.norm(b, 'fro')
7.745966692414834
>>> LA.norm(a, np.inf)
4.0
>>> LA.norm(b, np.inf)
9.0
>>> LA.norm(a, -np.inf)
0.0
>>> LA.norm(b, -np.inf)
2.0
>>> LA.norm(a, 1)
20.0
>>> LA.norm(b, 1)
7.0
>>> LA.norm(a, -1)
0.0
>>> LA.norm(b, -1)
6.0
>>> LA.norm(a, 2)
7.745966692414834
>>> LA.norm(b, 2)
7.3484692283495345
>>> LA.norm(a, -2)
0.0
>>> LA.norm(b, -2)
1.8570331885190563e-016 # may vary
>>> LA.norm(a, 3)
5.8480354764257312 # may vary
>>> LA.norm(a, -3)
0.0

Using the axis argument to compute vector norms:

>>> c = np.array([[ 1, 2, 3],
...               [-1, 1, 4]])
>>> LA.norm(c, axis=0)
array([ 1.41421356,  2.23606798,  5.        ])
>>> LA.norm(c, axis=1)
array([ 3.74165739,  4.24264069])
>>> LA.norm(c, ord=1, axis=1)
array([ 6.,  6.])

Using the axis argument to compute matrix norms:

>>> m = np.arange(8).reshape(2,2,2)
>>> LA.norm(m, axis=(1,2))
array([  3.74165739,  11.22497216])
>>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :])
(3.7416573867739413, 11.224972160321824)

jaxfit.prepare_bounds(bounds, n)[source]

Prepare bounds for optimization.

This function prepares the bounds for the optimization by ensuring that they are both 1-D arrays of length n. If either bound is a scalar, it is resized to an array of length n.

Parameters:
  • bounds (Tuple[np.ndarray, np.ndarray]) – The lower and upper bounds for the optimization.

  • n (int) – The length of the bounds arrays.

Returns:

The prepared lower and upper bounds arrays.

Return type:

Tuple[np.ndarray, np.ndarray]
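
The scalar-to-array expansion described for prepare_bounds can be sketched as follows. The name prepare_bounds_sketch and the conversion to float are assumptions made for illustration.

```python
import numpy as np


def prepare_bounds_sketch(bounds, n):
    """Hypothetical sketch: return lower and upper bounds as 1-D arrays of length n."""
    lb, ub = (np.asarray(b, dtype=float) for b in bounds)
    if lb.ndim == 0:
        lb = np.resize(lb, n)   # repeat the scalar lower bound n times
    if ub.ndim == 0:
        ub = np.resize(ub, n)   # repeat the scalar upper bound n times
    return lb, ub


lb, ub = prepare_bounds_sketch((0.0, [1.0, 2.0, np.inf]), 3)
print(lb, ub)
```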

jaxfit.print_header_linear()[source]

jaxfit.print_header_nonlinear()[source]

jaxfit.print_iteration_linear(iteration, cost, cost_reduction, step_norm, optimality)[source]

jaxfit.print_iteration_nonlinear(iteration, nfev, cost, cost_reduction, step_norm, optimality)[source]

jaxfit.reflective_transformation(y, lb, ub)[source]

Compute reflective transformation and its gradient.

jaxfit.solve_lsq_trust_region(n, m, uf, s, V, Delta, initial_alpha=None, rtol=0.01, max_iter=10)[source]

Solve a trust-region problem arising in least-squares minimization. This function implements a method described by J. J. More [12] and used in MINPACK, but it relies on a single SVD of the Jacobian instead of a series of Cholesky decompositions. Before running this function, compute: U, s, VT = svd(J, full_matrices=False).

Parameters:
  • n (int) – Number of variables.

  • m (int) – Number of residuals.

  • uf (ndarray) – Computed as U.T.dot(f).

  • s (ndarray) – Singular values of J.

  • V (ndarray) – Transpose of VT.

  • Delta (float) – Radius of a trust region.

  • initial_alpha (float, optional) – Initial guess for alpha, which might be available from a previous iteration. If None, determined automatically.

  • rtol (float, optional) – Stopping tolerance for the root-finding procedure. Namely, the solution p will satisfy abs(norm(p) - Delta) < rtol * Delta.

  • max_iter (int, optional) – Maximum allowed number of iterations for the root-finding procedure.

Returns:

  • p (ndarray, shape (n,)) – Found solution of a trust-region problem.

  • alpha (float) – Positive value such that (J.T*J + alpha*I)*p = -J.T*f. Sometimes called Levenberg-Marquardt parameter.

  • n_iter (int) – Number of iterations made by root-finding procedure. Zero means that Gauss-Newton step was selected as the solution.
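
The inputs uf, s and V can be prepared as sketched below, and the same SVD yields the step p for any alpha without further factorizations. The sketch uses numpy.linalg.svd as a stand-in for the svd call named above, the matrix J and vector f are made-up data, and it does not reproduce the root-finding for alpha.

```python
import numpy as np

# Made-up problem with m = 3 residuals and n = 2 variables.
J = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
f = np.array([1.0, -2.0, 0.5])

U, s, VT = np.linalg.svd(J, full_matrices=False)
uf = U.T.dot(f)
V = VT.T

# alpha = 0 gives the Gauss-Newton step, the minimizer of ||J p + f||.
p_gn = -V.dot(uf / s)

# For alpha > 0 the step solves (J.T*J + alpha*I)*p = -J.T*f using the same SVD.
alpha = 0.1
p_alpha = -V.dot(s * uf / (s**2 + alpha))

print(np.linalg.norm(p_gn), np.linalg.norm(p_alpha))
```

Because the norm of p shrinks as alpha grows, a root-finder can adjust alpha until norm(p) matches Delta while reusing the single decomposition.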

jaxfit.solve_trust_region_2d(B, g, Delta)[source]

Solve a general trust-region problem in 2 dimensions. The problem is reformulated as a 4th order algebraic equation, the solution of which is found by numpy.roots.

Parameters:
  • B (ndarray, shape (2, 2)) – Symmetric matrix, defines a quadratic term of the function.

  • g (ndarray, shape (2,)) – Defines a linear term of the function.

  • Delta (float) – Radius of a trust region.

Returns:

  • p (ndarray, shape (2,)) – Found solution.

  • newton_step (bool) – Whether the returned solution is the Newton step which lies within the trust region.

jaxfit.step_size_to_bound(x, s, lb, ub)[source]

Compute the minimum step size required to reach a bound. The function computes a positive scalar t such that x + s * t is on the bound.

Returns:

  • step (float) – Computed step. Non-negative value.

  • hits (ndarray of int with shape of x) – Each element indicates whether a corresponding variable reaches the bound:

    • 0 - the bound was not hit.

    • -1 - the lower bound was hit.

    • 1 - the upper bound was hit.
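
A NumPy sketch of this computation is given below. The name step_size_to_bound_sketch is hypothetical, and the sketch assumes that x, s, lb and ub are arrays of the same shape with x inside the bounds.

```python
import numpy as np


def step_size_to_bound_sketch(x, s, lb, ub):
    """Hypothetical sketch: smallest t >= 0 such that x + s*t touches a bound."""
    x, s, lb, ub = (np.asarray(a, dtype=float) for a in (x, s, lb, ub))
    moving = np.nonzero(s)              # components with s == 0 never reach a bound
    steps = np.full_like(x, np.inf)
    with np.errstate(over="ignore"):
        # For each moving component, only the bound ahead gives a non-negative step.
        steps[moving] = np.maximum((lb - x)[moving] / s[moving],
                                   (ub - x)[moving] / s[moving])
    min_step = np.min(steps)
    # Components attaining the minimum hit a bound; the sign of s says which one.
    hits = np.equal(steps, min_step) * np.sign(s).astype(int)
    return min_step, hits


step, hits = step_size_to_bound_sketch(
    x=[0.5, 0.5], s=[1.0, -2.0], lb=[0.0, 0.0], ub=[1.0, 1.0]
)
print(step, hits)
```

Here the second component moves toward its lower bound twice as fast as the first, so it reaches that bound first, at t = 0.25.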

jaxfit.tree_flatten(tree, is_leaf=None)[source]

Flattens a pytree.

The flattening order (i.e. the order of elements in the output list) is deterministic, corresponding to a left-to-right depth-first tree traversal.

Parameters:
  • tree (Any) – a pytree to flatten.

  • is_leaf (Callable[[Any], bool] | None) – an optionally specified function that will be called at each flattening step. It should return a boolean, with true stopping the traversal and the whole subtree being treated as a leaf, and false indicating the flattening should traverse the current object.

Returns:

A pair where the first element is a list of leaf values and the second element is a treedef representing the structure of the flattened tree.

Return type:

Tuple[List[Any], PyTreeDef]
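
A small illustration, assuming that the function documented here is jax.tree_util.tree_flatten (the description matches, but the source does not state this explicitly). The params dictionary is made-up data.

```python
import jax

params = {"b": (2.0, 3.0), "a": 1.0}

# Dictionaries are traversed in sorted key order, so "a" comes before "b".
leaves, treedef = jax.tree_util.tree_flatten(params)

# The treedef is enough to rebuild the original structure from the leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

# With is_leaf, tuples are kept whole instead of being traversed.
tuple_leaves, _ = jax.tree_util.tree_flatten(
    params, is_leaf=lambda node: isinstance(node, tuple)
)

print(leaves, tuple_leaves)
```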

jaxfit.update_tr_radius(Delta, actual_reduction, predicted_reduction, step_norm, bound_hit)[source]

Update the radius of a trust region based on the cost reduction.

Returns:

  • Delta (float) – New radius.

  • ratio (float) – Ratio between actual and predicted reductions.
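
A typical trust-region update rule is sketched below. The thresholds 0.25 and 0.75, the shrink factor, and the doubling factor are assumptions modelled on common trust-region practice and are not confirmed by the text above. The name update_tr_radius_sketch is hypothetical, and bound_hit is assumed to mean that the step reached the trust-region boundary.

```python
def update_tr_radius_sketch(Delta, actual_reduction, predicted_reduction,
                            step_norm, bound_hit):
    """Hypothetical sketch of a radius update driven by the reduction ratio."""
    if predicted_reduction > 0:
        ratio = actual_reduction / predicted_reduction
    elif predicted_reduction == actual_reduction == 0:
        ratio = 1.0     # nothing predicted and nothing gained: treat as agreement
    else:
        ratio = 0.0     # the model predicted no decrease: treat the step as poor

    if ratio < 0.25:
        Delta = 0.25 * step_norm    # poor agreement: shrink the region
    elif ratio > 0.75 and bound_hit:
        Delta *= 2.0                # good agreement at the boundary: expand
    return Delta, ratio


expanded = update_tr_radius_sketch(1.0, 0.9, 1.0, 1.0, True)
shrunk = update_tr_radius_sketch(1.0, 0.1, 1.0, 0.8, False)
print(expanded, shrunk)
```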

jaxfit.warn(message, category=None, stacklevel=1, source=None)

Issue a warning, or maybe ignore it or raise an exception.
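
The three outcomes named above (issue, ignore, raise) depend on the active warning filters. The sketch below assumes that the function documented here is the standard library's warnings.warn, which the signature suggests; the message texts are made up.

```python
import warnings

# "always": the warning is issued and can be recorded for inspection.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warnings.warn("tolerance is below machine epsilon", RuntimeWarning)

# "ignore": the same call is silently dropped.
with warnings.catch_warnings(record=True) as ignored:
    warnings.simplefilter("ignore")
    warnings.warn("this one is not recorded")

# "error": the warning is turned into an exception of its category.
with warnings.catch_warnings():
    warnings.simplefilter("error")
    try:
        warnings.warn("escalated to an exception")
        raised = False
    except UserWarning:             # UserWarning is the default category
        raised = True

print(caught[0].category.__name__, len(ignored), raised)
```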