jaxfit.least_squares module

Generic interface for least-squares minimization.

class jaxfit.least_squares.AutoDiffJacobian[source]

Bases: object

Wraps the residual fit function such that a masked jacfwd is performed on it, thereby giving the autodiff Jacobian. This needs to be a class since we need to maintain in memory three different versions of the Jacobian.

create_ad_jacobian(func, num_args, masked=True)[source]

Creates a function that returns the autodiff Jacobian of the residual fit function. The Jacobian of the residual fit function is equivalent to the Jacobian of the fit function, since the data term subtracted in the residual does not depend on the fit parameters.

Parameters:
  • func (Callable) – The function to take the jacobian of.

  • num_args (int) – The number of arguments the function takes.

  • masked (bool, optional) – Whether to use a masked Jacobian. Defaults to True.

Returns:

The function that returns the autodiff jacobian of the given function.

Return type:

Callable
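The idea of a masked jacfwd can be illustrated with a minimal sketch. This is not the library's implementation: the helper name make_masked_jacobian, the calling convention func(xdata, *params), and the argument order of the returned function are assumptions made for illustration. Only jax.jacfwd, jax.jit and jnp.where are real JAX calls.

```python
import jax
import jax.numpy as jnp


def make_masked_jacobian(func, num_args):
    """Return a function giving d(residual)/d(params) with masked rows zeroed.

    `func(xdata, *params)` is a hypothetical fit function; this helper and
    its signature are assumptions for the sketch, not the jaxfit API.
    """

    def masked_residual(params, xdata, ydata, data_mask):
        # Unpack exactly `num_args` parameters from the parameter vector.
        args = [params[i] for i in range(num_args)]
        residual = func(xdata, *args) - ydata
        # Masked-out points give a constant zero residual, so their
        # Jacobian rows are zero as well.
        return jnp.where(data_mask, residual, 0.0)

    # Forward-mode Jacobian with respect to the parameter vector (argument 0).
    return jax.jit(jax.jacfwd(masked_residual, argnums=0))


def line(x, a, b):
    return a * x + b


xdata = jnp.array([0.0, 1.0, 2.0, 3.0])
ydata = jnp.array([1.0, 3.0, 5.0, 7.0])
data_mask = jnp.array([True, True, True, False])  # last point is ignored

jac_fn = make_masked_jacobian(line, num_args=2)
J = jac_fn(jnp.array([2.0, 1.0]), xdata, ydata, data_mask)
print(J)  # rows are [x_i, 1] for unmasked points and [0, 0] for the masked one
```

Because ydata enters the residual only as a subtracted constant, the unmasked rows equal the Jacobian of the fit function itself.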

class jaxfit.least_squares.LeastSquares[source]

Bases: object

autdiff_jac(jac)[source]

Sets up the autodiff Jacobian for all three sigma-transformed functions (no sigma, 1D sigma, and covariance-matrix sigma), so that no retracing is needed if sigma later changes from one of these forms to another.

Parameters:

jac (None) – Passed in to maintain compatibility with the user defined Jacobian function.

Return type:

None
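To show why keeping three separate Jacobians avoids retracing, here is a hedged sketch with one jitted function per sigma form. The residual definitions below (division by sigma in the 1D case, whitening with a lower Cholesky factor in the covariance case) are assumptions borrowed from the usual weighted least-squares convention; the source does not spell them out, and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def line(x, a, b):
    return a * x + b


def residual_none(params, xdata, ydata):
    return line(xdata, *params) - ydata


def residual_1d(params, xdata, ydata, inv_sigma):
    # Assumed 1D transform: elementwise weights 1/sigma.
    return inv_sigma * (line(xdata, *params) - ydata)


def residual_cov(params, xdata, ydata, chol_lower):
    # Assumed covariance transform: solve L z = r, with L the lower
    # Cholesky factor of the covariance matrix.
    r = line(xdata, *params) - ydata
    return solve_triangular(chol_lower, r, lower=True)


# One compiled Jacobian per sigma form; switching forms picks a different
# entry instead of retracing a single function with a new signature.
jacobians = {
    "none": jax.jit(jax.jacfwd(residual_none)),
    "1d": jax.jit(jax.jacfwd(residual_1d)),
    "cov": jax.jit(jax.jacfwd(residual_cov)),
}

params = jnp.array([2.0, 1.0])
xdata = jnp.array([0.0, 1.0, 2.0])
ydata = jnp.zeros(3)
sigma = jnp.array([1.0, 2.0, 4.0])

J_none = jacobians["none"](params, xdata, ydata)
J_1d = jacobians["1d"](params, xdata, ydata, 1.0 / sigma)
J_cov = jacobians["cov"](params, xdata, ydata, jnp.diag(sigma))
```

With a diagonal covariance factor, the covariance form reduces to the 1D form, which is a convenient sanity check for a sketch like this.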

least_squares(fun, x0, jac=None, bounds=(-inf, inf), method='trf', ftol=1e-08, xtol=1e-08, gtol=1e-08, x_scale=1.0, loss='linear', f_scale=1.0, diff_step=None, tr_solver=None, tr_options={}, jac_sparsity=None, max_nfev=None, verbose=0, xdata=None, ydata=None, data_mask=None, transform=None, timeit=False, args=(), kwargs={})[source]
Parameters:
  • fun (Callable) –

  • x0 (ndarray) –

  • jac (Callable | None) –

  • bounds (Tuple[ndarray, ndarray]) –

  • method (str) –

  • ftol (float) –

  • xtol (float) –

  • gtol (float) –

  • x_scale (str | ndarray | float) –

  • loss (str) –

  • f_scale (float) –

  • max_nfev (float | None) –

  • verbose (int) –

  • xdata (Array | None) –

  • ydata (Array | None) –

  • data_mask (Array | None) –

  • transform (Array | None) –

  • timeit (bool) –

update_function(func)[source]

Wraps the given fit function to be a residual function using the data. The wrapped function is in a JAX JIT compatible format which is purely functional. This requires that both the data mask and the uncertainty transform are passed to the function. Even for the case where the data mask is all True and the uncertainty transform is None, we still need to pass these arguments to the function due to JAX’s functional nature.

Parameters:

func (Callable) – The fit function to wrap.

Return type:

None
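A minimal sketch of such a purely functional residual wrapper follows. The name make_residual, the argument order, and the choice to represent the transform as 1D weights (1/sigma) are assumptions for illustration; the library's own wrapper may differ.

```python
import jax
import jax.numpy as jnp


def make_residual(func):
    """Wrap a hypothetical fit function `func(xdata, *params)` as a residual."""

    @jax.jit
    def residual(params, xdata, ydata, data_mask, transform):
        diff = func(xdata, *params) - ydata
        # Assumed 1D case: `transform` holds the weights 1/sigma.
        weighted = transform * diff
        # Everything the function needs arrives as an argument, so it has
        # no hidden state and stays JIT compatible.
        return jnp.where(data_mask, weighted, 0.0)

    return residual


def line(x, a, b):
    return a * x + b


xdata = jnp.array([0.0, 1.0, 2.0])
ydata = jnp.array([0.0, 4.0, 4.0])
data_mask = jnp.array([True, True, False])
transform = jnp.array([1.0, 0.5, 1.0])

residual_fn = make_residual(line)
r = residual_fn(jnp.array([1.0, 1.0]), xdata, ydata, data_mask, transform)
print(r)  # weighted residuals, with the masked entry forced to zero
```

In this sketch an unweighted fit would pass an array of ones as the transform and an all-True mask, which mirrors the point above that both arguments are always supplied.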

wrap_jac(jac)[source]

Wraps a user-defined Jacobian function to allow for data masking and uncertainty transforms. The wrapped function is in a JAX JIT compatible format which is purely functional. This requires that both the data mask and the uncertainty transform are passed to the function.

The analytical Jacobian of the fit function is equivalent to the Jacobian of the residual function.

Also note that the analytical Jacobian doesn’t require the dependent data ydata, but we still need to pass it to the function to maintain compatibility with the autodiff version, which does require the ydata.

Parameters:

jac (Callable) – The Jacobian function to wrap.

Returns:

The masked Jacobian of the function evaluated at args with respect to the data.

Return type:

jnp.ndarray
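The wrapping described above might look like the following sketch. The helper name wrap_user_jacobian, the user Jacobian convention jac(xdata, *params) returning an (m, n) array, and the 1D weight form of the transform are all assumptions, not the library's confirmed behavior.

```python
import jax
import jax.numpy as jnp


def wrap_user_jacobian(jac):
    """Wrap a hypothetical user Jacobian `jac(xdata, *params)` -> (m, n)."""

    @jax.jit
    def wrapped(params, xdata, ydata, data_mask, transform):
        # `ydata` is unused here; it is accepted only so the signature
        # matches the autodiff Jacobian.
        J = jac(xdata, *params)
        # Assumed 1D transform: scale each row by its weight 1/sigma.
        J = transform[:, None] * J
        # Zero the rows that belong to masked-out data points.
        return jnp.where(data_mask[:, None], J, 0.0)

    return wrapped


def line_jac(x, a, b):
    # For f = a * x + b: df/da = x and df/db = 1.
    return jnp.stack([x, jnp.ones_like(x)], axis=1)


xdata = jnp.array([0.0, 1.0, 2.0])
ydata = jnp.zeros(3)  # ignored by the analytical Jacobian
data_mask = jnp.array([True, True, False])
transform = jnp.array([1.0, 0.5, 1.0])

jac_fn = wrap_user_jacobian(line_jac)
J = jac_fn(jnp.array([2.0, 1.0]), xdata, ydata, data_mask, transform)
```

Keeping the same five-argument signature for both the analytical and autodiff paths lets the solver call either one without branching.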

jaxfit.least_squares.check_tolerance(ftol, xtol, gtol, method)[source]

Check and prepare tolerance values for optimization.

This function checks the tolerance values for the optimization and prepares them for use. If any of the tolerances is None, it is set to 0. If any of the tolerances is lower than the machine epsilon, a warning is issued and the tolerance is set to the machine epsilon. If all tolerances are lower than the machine epsilon, a ValueError is raised.

Parameters:
  • ftol (float) – The tolerance for the optimization function value.

  • xtol (float) – The tolerance for the optimization variable values.

  • gtol (float) – The tolerance for the optimization gradient values.

  • method (str) – The name of the optimization method.

Returns:

The prepared tolerance values.

Return type:

Tuple[float, float, float]
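The rules stated above can be sketched in plain Python. This is an illustration, not the library's code: the function name, the order of the checks, and the choice to leave a None tolerance at 0 without clamping it are assumptions, and the method argument is omitted because its effect is not described here.

```python
import sys
import warnings

EPS = sys.float_info.epsilon  # machine epsilon for 64-bit floats


def check_tolerance_sketch(ftol, xtol, gtol):
    """Hypothetical re-statement of the documented tolerance rules."""
    given = {"ftol": ftol, "xtol": xtol, "gtol": gtol}

    # None means the criterion is disabled, which is represented by 0.
    values = {k: 0.0 if v is None else float(v) for k, v in given.items()}

    # If every tolerance is below machine epsilon, nothing can terminate
    # the optimization reliably.
    if all(v < EPS for v in values.values()):
        raise ValueError("At least one tolerance must be above machine epsilon.")

    prepared = []
    for name, value in values.items():
        if given[name] is not None and value < EPS:
            warnings.warn(f"{name} is below machine epsilon; using {EPS:.2e}.")
            value = EPS
        prepared.append(value)
    return tuple(prepared)


print(check_tolerance_sketch(1e-8, None, 1e-8))
```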

jaxfit.least_squares.check_x_scale(x_scale, x0)[source]

Check and prepare the x_scale parameter for optimization.

This function checks and prepares the x_scale parameter for the optimization. x_scale can either be ‘jac’ or an array_like with positive numbers. If it is ‘jac’, the scaling is derived from the Jacobian.

Parameters:
  • x_scale (Union[str, Sequence[float]]) – The scaling for the optimization variables.

  • x0 (Sequence[float]) – The initial guess for the optimization variables.

Returns:

The prepared x_scale parameter.

Return type:

Union[str, Sequence[float]]
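A hedged NumPy sketch of this kind of validation is shown below. The function name is hypothetical, and two behaviors are assumptions that the description above does not state: a scalar x_scale is broadcast to the shape of x0, and a shape mismatch raises a ValueError.

```python
import numpy as np


def check_x_scale_sketch(x_scale, x0):
    """Hypothetical validation of `x_scale` against the initial guess `x0`."""
    # The string 'jac' is passed through unchanged.
    if isinstance(x_scale, str) and x_scale == "jac":
        return x_scale

    try:
        x_scale = np.asarray(x_scale, dtype=float)
        valid = bool(np.all(np.isfinite(x_scale)) and np.all(x_scale > 0))
    except (ValueError, TypeError):
        valid = False
    if not valid:
        raise ValueError("x_scale must be 'jac' or array_like with positive numbers.")

    # Assumption: a scalar is broadcast to one scale per variable.
    if x_scale.ndim == 0:
        x_scale = np.resize(x_scale, np.shape(x0))
    # Assumption: otherwise the shape must match x0 exactly.
    if x_scale.shape != np.shape(x0):
        raise ValueError("Inconsistent shapes between x_scale and x0.")
    return x_scale


print(check_x_scale_sketch(2.0, [1.0, 2.0, 3.0]))
print(check_x_scale_sketch("jac", [1.0, 2.0, 3.0]))
```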

jaxfit.least_squares.prepare_bounds(bounds, n)[source]

Prepare bounds for optimization.

This function prepares the bounds for the optimization by ensuring that they are both 1-D arrays of length n. If either bound is a scalar, it is resized to an array of length n.

Parameters:
  • bounds (Tuple[np.ndarray, np.ndarray]) – The lower and upper bounds for the optimization.

  • n (int) – The length of the bounds arrays.

Returns:

The prepared lower and upper bounds arrays.

Return type:

Tuple[np.ndarray, np.ndarray]
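The scalar-to-array step described above can be sketched with NumPy as follows. The function name is hypothetical, and this sketch performs no validation beyond the resizing.

```python
import numpy as np


def prepare_bounds_sketch(bounds, n):
    """Return lower and upper bounds as 1-D float arrays of length n."""
    lb, ub = (np.asarray(b, dtype=float) for b in bounds)
    # A scalar bound is repeated so that each variable gets its own entry.
    if lb.ndim == 0:
        lb = np.resize(lb, n)
    if ub.ndim == 0:
        ub = np.resize(ub, n)
    return lb, ub


# Scalar lower bound, per-variable upper bound.
lb, ub = prepare_bounds_sketch((-np.inf, [1.0, 2.0, 3.0]), 3)
print(lb, ub)
```

The default bounds=(-inf, inf) in the least_squares signature is the case where both entries are scalars and both get expanded this way.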