jaxfit.minpack module

class jaxfit.minpack.CurveFit(flength=None)[source]

Bases: object

Parameters:

flength (float | None) – The fixed input data length.

create_covariance_svd()[source]

Create JIT-compiled SVD function for covariance computation.

create_sigma_transform_funcs()[source]

Create JIT-compiled sigma transform functions.

This function creates two JIT-compiled functions: sigma_transform1d and sigma_transform2d, which are used to compute the sigma transform for 1D and 2D data, respectively. The functions are stored as attributes of the object on which the method is called.

curve_fit(f, xdata, ydata, p0=None, sigma=None, absolute_sigma=False, check_finite=True, bounds=(-inf, inf), method=None, jac=None, data_mask=None, timeit=False, return_eval=False, **kwargs)[source]

Use non-linear least squares to fit a function, f, to data. Assumes ydata = f(xdata, *params) + eps.

Parameters:
  • f (callable) – The model function, f(x, …). It must take the independent variable as the first argument and the parameters to fit as separate remaining arguments.

  • xdata (array_like or object) – The independent variable where the data is measured. Should usually be an M-length sequence or a (k, M)-shaped array for functions with k predictors, but can actually be any object.

  • ydata (array_like) – The dependent data, a length M array - nominally f(xdata, ...).

  • p0 (array_like, optional) – Initial guess for the parameters (length N). If None, then the initial values will all be 1 (if the number of parameters for the function can be determined using introspection, otherwise a ValueError is raised).

  • sigma (None or M-length sequence or MxM array, optional) –

    Determines the uncertainty in ydata. If we define residuals as r = ydata - f(xdata, *popt), then the interpretation of sigma depends on its number of dimensions:

    - A 1-D sigma should contain values of standard deviations of errors in ydata. In this case, the optimized function is chisq = sum((r / sigma) ** 2).

    - A 2-D sigma should contain the covariance matrix of errors in ydata. In this case, the optimized function is chisq = r.T @ inv(sigma) @ r.

    New in version 0.19.

    None (default) is equivalent to a 1-D sigma filled with ones.

  • absolute_sigma (bool, optional) – If True, sigma is used in an absolute sense and the estimated parameter covariance pcov reflects these absolute values. If False (default), only the relative magnitudes of the sigma values matter. The returned parameter covariance matrix pcov is then based on scaling sigma by a constant factor. This constant is set by demanding that the reduced chisq for the optimal parameters popt, when using the scaled sigma, equals unity. In other words, sigma is scaled to match the sample variance of the residuals after the fit. Mathematically, pcov(absolute_sigma=False) = pcov(absolute_sigma=True) * chisq(popt) / (M - N).

  • check_finite (bool, optional) – If True, check that the input arrays do not contain NaNs or infs, and raise a ValueError if they do. Setting this parameter to False may silently produce nonsensical results if the input arrays do contain NaNs. Default is True.

  • bounds (2-tuple of array_like, optional) – Lower and upper bounds on parameters. Defaults to no bounds. Each element of the tuple must be either an array with the length equal to the number of parameters, or a scalar (in which case the bound is taken to be the same for all parameters). Use np.inf with an appropriate sign to disable bounds on all or some parameters.

    New in version 0.17.

  • method ({'trf'}, optional) – Method to use for optimization. See least_squares for more details. Currently only ‘trf’ is implemented.

    New in version 0.17.

  • jac (callable, string or None, optional) – Function with signature jac(x, ...) which computes the Jacobian matrix of the model function with respect to parameters as a dense array_like structure. It will be scaled according to provided sigma. If None (default), the Jacobian will be determined using JAX’s automatic differentiation (AD) capabilities. We recommend not using an analytical Jacobian, as it is usually faster to use AD.

  • kwargs – Keyword arguments passed to leastsq for method='lm' or least_squares otherwise.

  • data_mask (ndarray | None) –

  • timeit (bool) –

  • return_eval (bool) –
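
The two interpretations of sigma described above can be checked numerically. The following is a minimal sketch in plain NumPy that does not call jaxfit; the residuals `r` and the uncertainties `sigma_1d` are made-up illustrative values. It shows that a 1-D sigma gives the same chisq as a 2-D covariance matrix whose diagonal holds the squared standard deviations, and that sigma=None corresponds to unit weights.

```python
import numpy as np

# Illustrative residuals r = ydata - f(xdata, *popt); the values are made up.
r = np.array([0.1, -0.2, 0.05, 0.3])
sigma_1d = np.array([0.1, 0.2, 0.1, 0.3])

# 1-D sigma: standard deviations of the errors in ydata.
chisq_1d = np.sum((r / sigma_1d) ** 2)

# 2-D sigma: covariance matrix of the errors in ydata.
# A diagonal covariance with sigma**2 on the diagonal is the same statement
# as independent errors with standard deviations sigma_1d.
sigma_2d = np.diag(sigma_1d ** 2)
chisq_2d = r.T @ np.linalg.inv(sigma_2d) @ r

# sigma=None behaves like a 1-D sigma filled with ones.
chisq_default = np.sum((r / np.ones_like(r)) ** 2)

print(np.isclose(chisq_1d, chisq_2d))
```

A non-diagonal 2-D sigma additionally encodes correlations between data points, which a 1-D sigma cannot express.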

Returns:

  • popt (array) – Optimal values for the parameters so that the sum of the squared residuals of f(xdata, *popt) - ydata is minimized.

  • pcov (2-D array) – The estimated covariance of popt. The diagonals provide the variance of the parameter estimate. To compute one standard deviation errors on the parameters use perr = np.sqrt(np.diag(pcov)). How the sigma parameter affects the estimated covariance depends on the absolute_sigma argument, as described above. If the Jacobian matrix at the solution does not have full rank, the ‘lm’ method returns a matrix filled with np.inf, whereas the ‘trf’ and ‘dogbox’ methods use the Moore-Penrose pseudoinverse to compute the covariance matrix.
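
The relation between pcov, perr and absolute_sigma can be illustrated on a problem small enough to solve by hand. The sketch below is an assumption-laden stand-in, not jaxfit's implementation: it fits a straight line with NumPy's linear least squares, and the helper `linear_pcov`, the model and the synthetic data are all hypothetical. It computes the covariance with the Moore-Penrose pseudoinverse and then applies the scaling chisq(popt) / (M - N).

```python
import numpy as np

def linear_pcov(x, y, sigma):
    """Illustrative helper (not part of jaxfit): fit y = a * x + b and
    return popt together with both covariance variants."""
    M, N = x.size, 2
    # Jacobian of the model with respect to (a, b), scaled by 1 / sigma.
    J = np.column_stack([x, np.ones(M)]) / sigma[:, None]
    popt, *_ = np.linalg.lstsq(J, y / sigma, rcond=None)
    r = y - (popt[0] * x + popt[1])
    chisq = np.sum((r / sigma) ** 2)
    pcov_abs = np.linalg.pinv(J.T @ J)       # absolute_sigma=True
    pcov_rel = pcov_abs * chisq / (M - N)    # absolute_sigma=False
    return popt, pcov_abs, pcov_rel

rng = np.random.default_rng(0)
x = np.linspace(0.0, 4.0, 50)
y = 1.5 * x + 0.3 + 0.2 * rng.normal(size=x.size)
sigma = np.full(x.size, 0.5)

popt, pcov_abs, pcov_rel = linear_pcov(x, y, sigma)
perr = np.sqrt(np.diag(pcov_rel))  # one standard deviation errors

# Multiplying sigma by a constant changes pcov_abs but leaves pcov_rel alone.
_, pcov_abs_scaled, pcov_rel_scaled = linear_pcov(x, y, 10.0 * sigma)
print(np.allclose(pcov_rel, pcov_rel_scaled))
```

This is why only the relative magnitudes of sigma matter when absolute_sigma is False: a common factor in sigma cancels between the pseudoinverse and chisq.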

Raises:
  • ValueError – if either ydata or xdata contain NaNs, or if incompatible options are used.

  • RuntimeError – if the least-squares minimization fails.

  • OptimizeWarning – if covariance of the parameters cannot be estimated.
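
The ValueError raised for non-finite input corresponds to the check_finite option. As a rough sketch of what such a check amounts to (the helper `check_all_finite` is hypothetical and is not a jaxfit function), the input can be validated with np.isfinite before fitting:

```python
import numpy as np

def check_all_finite(name, arr):
    """Hypothetical helper: raise ValueError if arr contains NaN or inf."""
    arr = np.asarray(arr, dtype=float)
    if not np.all(np.isfinite(arr)):
        raise ValueError(f"{name} must not contain NaNs or infs")
    return arr

xdata = check_all_finite("xdata", [0.0, 1.0, 2.0])

try:
    check_all_finite("ydata", [1.0, np.nan, 3.0])
    raised = False
except ValueError:
    raised = True

print(raised)
```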

Return type:

Tuple[ndarray, ndarray]

See also

least_squares

Minimize the sum of squares of nonlinear functions.

Notes

Refer to the docstring of least_squares for more information.

Examples

>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> import jax.numpy as jnp
>>> from jaxfit import CurveFit
>>> def func(x, a, b, c):
...     return a * jnp.exp(-b * x) + c

Define the data to be fit with some noise:

>>> xdata = np.linspace(0, 4, 50)
>>> y = func(xdata, 2.5, 1.3, 0.5)
>>> rng = np.random.default_rng()
>>> y_noise = 0.2 * rng.normal(size=xdata.size)
>>> ydata = y + y_noise
>>> plt.plot(xdata, ydata, 'b-', label='data')

Fit for the parameters a, b, c of the function func:

>>> cf = CurveFit()
>>> popt, pcov = cf.curve_fit(func, xdata, ydata)
>>> popt
array([2.56274217, 1.37268521, 0.47427475])
>>> plt.plot(xdata, func(xdata, *popt), 'r-',
...          label='fit: a=%5.3f, b=%5.3f, c=%5.3f' % tuple(popt))

Constrain the optimization to the region of 0 <= a <= 3, 0 <= b <= 1 and 0 <= c <= 0.5:

>>> cf = CurveFit()
>>> popt, pcov = cf.curve_fit(func, xdata, ydata, bounds=(0, [3., 1., 0.5]))
>>> popt
array([2.43736712, 1.        , 0.34463856])
>>> plt.plot(xdata, func(xdata, *popt), 'g--',
...          label='fit: a=%5.3f, b=%5.3f, c=%5.3f' % tuple(popt))
>>> plt.xlabel('x')
>>> plt.ylabel('y')
>>> plt.legend()
>>> plt.show()

pad_fit_data(xdata, ydata, xdims, len_diff)[source]

Pad fit data to match the fixed input data length.

This function pads the input data arrays with small values to match the fixed input data length, which avoids JAX retracing the JITted functions. The padding is added along the second dimension of the xdata array if the data is multidimensional, and along the first dimension otherwise. The padding value is EPS, a global constant defined as a very small positive value, which avoids numerical issues.

Parameters:
  • xdata (np.ndarray) – The independent variables of the data.

  • ydata (np.ndarray) – The dependent variables of the data.

  • xdims (int) – The number of dimensions in the xdata array.

  • len_diff (int) – The difference in length between the data arrays and the fixed input data length.

Returns:

The padded xdata and ydata arrays.

Return type:

Tuple[np.ndarray, np.ndarray]
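
The padding behaviour described above can be sketched in NumPy as follows. This is a hypothetical re-implementation written from the description, not the library's source: the function name `pad_with_eps` is invented, the value of `EPS` is an assumption, and xdims is assumed to mean the number of array dimensions of xdata.

```python
import numpy as np

EPS = 1e-10  # assumed value; the library's EPS constant may differ

def pad_with_eps(xdata, ydata, xdims, len_diff):
    """Hypothetical sketch of the padding described above."""
    if xdims > 1:
        # Multidimensional xdata of shape (k, M): pad along the second dimension.
        xpad = np.full((xdata.shape[0], len_diff), EPS)
        xdata = np.concatenate([xdata, xpad], axis=1)
    else:
        # 1-D xdata of length M: pad along the first (only) dimension.
        xdata = np.concatenate([xdata, np.full(len_diff, EPS)])
    ydata = np.concatenate([ydata, np.full(len_diff, EPS)])
    return xdata, ydata

# 1-D case: 5 points padded to a fixed length of 8.
x1, y1 = pad_with_eps(np.arange(5.0), np.ones(5), xdims=1, len_diff=3)

# 2-D case: 2 predictors, 5 points each, padded to a fixed length of 8.
x2, y2 = pad_with_eps(np.ones((2, 5)), np.ones(5), xdims=2, len_diff=3)

print(x1.shape, y1.shape, x2.shape, y2.shape)
```

Padded points presumably need to be excluded from the fit itself, for example through a mask such as the data_mask argument of curve_fit; how jaxfit handles this internally is not stated in this section.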

update_flength(flength)[source]

Set the fixed input data length.

Parameters:

flength (float) – The fixed input data length.
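
A fixed input data length matters because a JIT-compiled function is traced again whenever it sees an input shape it has not seen before. The toy model below imitates that behaviour in plain Python; `ShapeKeyedCache` and `pad_to` are hypothetical names and the class is only a stand-in for a JIT cache, not JAX itself. It counts how many distinct shapes are "compiled" with and without padding to a common length.

```python
import numpy as np

class ShapeKeyedCache:
    """Toy stand-in for a JIT cache: 'compiles' once per distinct input shape."""

    def __init__(self, fn):
        self.fn = fn
        self.compiled_shapes = set()
        self.n_compiles = 0

    def __call__(self, x):
        if x.shape not in self.compiled_shapes:
            # A new shape would trigger tracing and compilation in a real JIT.
            self.compiled_shapes.add(x.shape)
            self.n_compiles += 1
        return self.fn(x)

def pad_to(x, flength, fill=1e-10):
    """Pad a 1-D array with a small value up to the fixed length flength."""
    return np.concatenate([x, np.full(flength - x.size, fill)])

unpadded = ShapeKeyedCache(np.sum)
padded = ShapeKeyedCache(np.sum)

for m in (10, 20, 30):
    data = np.ones(m)
    unpadded(data)                    # a new shape on every call
    padded(pad_to(data, flength=32))  # always shape (32,)

print(unpadded.n_compiles, padded.n_compiles)  # prints "3 1"
```

Choosing the fixed length at least as large as the largest expected dataset keeps every call on the same compiled shape.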

jaxfit.minpack.curve_fit(f, *args, **kwargs)[source]