import warnings
import numpy as np
from numpy import (zeros, inf)
from inspect import signature
import time
from typing import Optional, Callable, Tuple, Union, List, Dict, Any
from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jit
from jax.scipy.linalg import svd as jax_svd
from jax.scipy.linalg import cholesky as jax_cholesky
from jaxfit._optimize import OptimizeWarning
from jaxfit.least_squares import prepare_bounds, LeastSquares
from jaxfit.common_scipy import EPS
__all__ = ['CurveFit', 'curve_fit']
def curve_fit(f, *args, **kwargs):
    """Fit ``f`` to data with a throw-away :class:`CurveFit` instance.

    Convenience wrapper around :meth:`CurveFit.curve_fit` for one-off fits.

    Parameters
    ----------
    f : callable
        The model function ``f(x, ...)``.
    *args, **kwargs
        Forwarded unchanged to :meth:`CurveFit.curve_fit` (``xdata``,
        ``ydata``, ``p0``, ...).

    Returns
    -------
    popt : array
        Optimal parameter values.
    pcov : 2-D array
        Estimated covariance of ``popt``.
    """
    # `CurveFit.__init__` only accepts `flength`; the model function has to be
    # handed to the `curve_fit` method instead.
    jcf = CurveFit()
    result = jcf.curve_fit(f, *args, **kwargs)
    # The method returns 2, 3 or 5 values depending on `return_eval` /
    # `timeit`; the first two are always (popt, pcov).
    return result[0], result[1]
def _initialize_feasible(lb: np.ndarray, ub: np.ndarray) -> np.ndarray:
"""Initialize feasible parameters for optimization.
This function initializes feasible parameters for optimization based on the
lower and upper bounds of the variables. If both bounds are finite, the
feasible parameters are set to the midpoint between the bounds. If only the
lower bound is finite, the feasible parameters are set to the lower bound
plus 1. If only the upper bound is finite, the feasible parameters are set
to the upper bound minus 1. If neither bound is finite, the feasible
parameters are set to 1.
Parameters
----------
lb : np.ndarray
The lower bounds of the variables.
ub : np.ndarray
The upper bounds of the variables.
Returns
-------
np.ndarray
The initialized feasible parameters.
"""
p0 = np.ones_like(lb)
lb_finite = np.isfinite(lb)
ub_finite = np.isfinite(ub)
mask = lb_finite & ub_finite
p0[mask] = 0.5 * (lb[mask] + ub[mask])
mask = lb_finite & ~ub_finite
p0[mask] = lb[mask] + 1
mask = ~lb_finite & ub_finite
p0[mask] = ub[mask] - 1
return p0
class CurveFit():
    """Reusable non-linear least-squares curve fitter.

    An instance holds an optional fixed data length (``flength``), a
    JIT-compiled SVD helper used for the covariance estimate and a
    ``LeastSquares`` solver object, so that repeated fits can share them.
    """

    def __init__(self, flength: Optional[float] = None):
        """CurveFit class for fitting

        Parameters
        ----------
        flength : float, optional
            fixed data length for fits, JAXFit pads input data to this length
            to avoid retracing.
        """
        self.flength = flength
        # NOTE(review): `create_sigma_transform_funcs` is not defined in the
        # visible part of this class; it is expected to attach the
        # `sigma_transform1d` / `sigma_transform2d` callables that
        # `curve_fit` uses -- confirm against the full file.
        self.create_sigma_transform_funcs()
        self.create_covariance_svd()
        self.ls = LeastSquares()

    def update_flength(self, flength: float):
        """Set the fixed input data length.

        Parameters
        ----------
        flength : float
            The fixed input data length.
        """
        self.flength = flength

    def create_covariance_svd(self):
        """Create JIT-compiled SVD function for covariance computation.

        The compiled function is stored as ``self.covariance_svd``; it takes
        a Jacobian matrix and returns the singular values and ``V^T`` of its
        reduced SVD (the left singular vectors are discarded).
        """
        @jit
        def covariance_svd(jac):
            _, s, VT = jax_svd(jac, full_matrices=False)
            return s, VT

        self.covariance_svd = covariance_svd

    def pad_fit_data(self, xdata: np.ndarray,
                     ydata: np.ndarray,
                     xdims: int,
                     len_diff: int
                     ) -> Tuple[np.ndarray, np.ndarray]:
        """Pad fit data to match the fixed input data length.

        The data arrays are extended with entries equal to ``EPS`` so that
        they reach the fixed input data length, which avoids JAX retracing
        the JITted functions. Multidimensional `xdata` is padded along its
        second axis, one-dimensional `xdata` along its only axis.

        Parameters
        ----------
        xdata : np.ndarray
            The independent variables of the data.
        ydata : np.ndarray
            The dependent variables of the data.
        xdims : int
            The number of dimensions in the `xdata` array.
        len_diff : int
            The difference in length between the data arrays and the fixed
            input data length.

        Returns
        -------
        Tuple[np.ndarray, np.ndarray]
            The padded `xdata` and `ydata` arrays.
        """
        if xdims > 1:
            # The padding must match `xdata` on every axis except the one
            # being extended. Using `xdims` (the number of axes) as the row
            # count only worked for data with exactly two predictor rows.
            pad_shape = list(np.shape(xdata))
            pad_shape[1] = len_diff
            xpad = EPS * np.ones(pad_shape)
            xdata = np.concatenate([xdata, xpad], axis=1)
        else:
            xpad = EPS * np.ones([len_diff])
            xdata = np.concatenate([xdata, xpad])
        ypad = EPS * np.ones([len_diff])
        ydata = np.concatenate([ydata, ypad])
        return xdata, ydata

    def curve_fit(self,
                  f: Callable,
                  xdata: Union[np.ndarray, Tuple[np.ndarray]],
                  ydata: np.ndarray,
                  p0: Optional[np.ndarray] = None,
                  sigma: Optional[np.ndarray] = None,
                  absolute_sigma: bool = False,
                  check_finite: bool = True,
                  bounds: Tuple[np.ndarray, np.ndarray] = (-np.inf, np.inf),
                  method: Optional[str] = None,
                  jac: Optional[Callable] = None,
                  data_mask: Optional[np.ndarray] = None,
                  timeit: bool = False,
                  return_eval: bool = False,
                  **kwargs
                  ) -> Tuple[Any, ...]:
        """
        Use non-linear least squares to fit a function, f, to data.

        Assumes ``ydata = f(xdata, *params) + eps``.

        Parameters
        ----------
        f : callable
            The model function, f(x, ...). It must take the independent
            variable as the first argument and the parameters to fit as
            separate remaining arguments.
        xdata : array_like
            The independent variable where the data is measured.
            Must be a list, tuple or NumPy array: an M-length sequence or
            an (k,M)-shaped array for functions with k predictors.
        ydata : array_like
            The dependent data, a length M array - nominally ``f(xdata, ...)``.
        p0 : array_like, optional
            Initial guess for the parameters (length N). If None, the number
            of parameters is determined by introspection of `f` (a ValueError
            is raised if that fails) and the initial values are derived from
            `bounds`: 1 for unbounded parameters, otherwise a point inside
            the bounds.
        sigma : None or M-length array or MxM array, optional
            Determines the uncertainty in `ydata`. Must be a NumPy array.
            If we define residuals as ``r = ydata - f(xdata, *popt)``, then
            the interpretation of `sigma` depends on its number of dimensions:

            - A 1-D `sigma` should contain values of standard deviations of
              errors in `ydata`. In this case, the optimized function is
              ``chisq = sum((r / sigma) ** 2)``.
            - A 2-D `sigma` should contain the covariance matrix of
              errors in `ydata`. In this case, the optimized function is
              ``chisq = r.T @ inv(sigma) @ r``.

            None (default) is equivalent of 1-D `sigma` filled with ones.
        absolute_sigma : bool, optional
            If True, `sigma` is used in an absolute sense and the estimated
            parameter covariance `pcov` reflects these absolute values.

            If False (default), only the relative magnitudes of the `sigma`
            values matter. The returned parameter covariance matrix `pcov` is
            based on scaling `sigma` by a constant factor. This constant is
            set by demanding that the reduced `chisq` for the optimal
            parameters `popt` when using the *scaled* `sigma` equals unity.
            In other words, `sigma` is scaled to match the sample variance of
            the residuals after the fit. Mathematically,
            ``pcov(absolute_sigma=False) = pcov(absolute_sigma=True) * chisq(popt)/(M-N)``
        check_finite : bool, optional
            If True, check that the input arrays do not contain nans or infs,
            and raise a ValueError if they do. Setting this parameter to
            False may silently produce nonsensical results if the input arrays
            do contain nans. Default is True.
        bounds : 2-tuple of array_like, optional
            Lower and upper bounds on parameters. Defaults to no bounds.
            Each element of the tuple must be either an array with the length
            equal to the number of parameters, or a scalar (in which case the
            bound is taken to be the same for all parameters). Use ``np.inf``
            with an appropriate sign to disable bounds on all or some
            parameters.
        method : {'trf'}, optional
            Method to use for optimization. See `least_squares` for more
            details. Defaults to 'trf'.
        jac : callable, string or None, optional
            Function with signature ``jac(x, ...)`` which computes the Jacobian
            matrix of the model function with respect to parameters as a dense
            array_like structure. It will be scaled according to provided
            `sigma`. If None (default), the Jacobian will be determined using
            JAX's automatic differentiation (AD) capabilities. We recommend
            not using an analytical Jacobian, as it is usually faster to use
            AD.
        data_mask : array_like of bool, optional
            Length-M mask forwarded to the least-squares solver. If None, all
            M points are marked True. Entries appended for fixed-length
            padding are always marked False.
            NOTE(review): True presumably means "use this point"; confirm
            against ``LeastSquares.least_squares``.
        timeit : bool, optional
            If True, block until the data is on the device so the transfer
            can be timed, and return timing information (see Returns).
        return_eval : bool, optional
            If True, additionally return the model evaluated at `popt`.
        kwargs
            Keyword arguments passed on to ``LeastSquares.least_squares``.
            ``maxfev`` is accepted as an alias for ``max_nfev``.

        Returns
        -------
        popt : array
            Optimal values for the parameters so that the sum of the squared
            residuals of ``f(xdata, *popt) - ydata`` is minimized.
        pcov : 2-D array
            The estimated covariance of popt. The diagonals provide the
            variance of the parameter estimate. To compute one standard
            deviation errors on the parameters use
            ``perr = np.sqrt(np.diag(pcov))``.

            How the `sigma` parameter affects the estimated covariance
            depends on `absolute_sigma` argument, as described above.

            The covariance is computed from the Jacobian at the solution with
            a Moore-Penrose pseudoinverse that discards (numerically) zero
            singular values.
        feval : array
            Only if ``return_eval=True`` (a 3-tuple is returned): the model
            evaluated at `popt`. When no `data_mask` was supplied the padding
            entries are stripped; otherwise the full (possibly padded) array
            is returned.
        res, post_time, ctime
            Only if ``timeit=True`` and ``return_eval=False`` (a 5-tuple is
            returned): the solver result with its ``jac`` and ``fun`` entries
            removed, the time spent on the covariance post-processing and the
            time spent converting the data to JAX arrays.

        Raises
        ------
        ValueError
            if either `ydata` or `xdata` contain NaNs, or if incompatible
            options are used.
        RuntimeError
            if the least-squares minimization fails.
        OptimizeWarning
            if covariance of the parameters can not be estimated.

        See Also
        --------
        least_squares : Minimize the sum of squares of nonlinear functions.

        Notes
        -----
        Refer to the docstring of `least_squares` for more information.

        Examples
        --------
        >>> import matplotlib.pyplot as plt
        >>> import numpy as np
        >>> import jax.numpy as jnp
        >>> from jaxfit import CurveFit

        >>> def func(x, a, b, c):
        ...     return a * jnp.exp(-b * x) + c

        Define the data to be fit with some noise:

        >>> xdata = np.linspace(0, 4, 50)
        >>> y = func(xdata, 2.5, 1.3, 0.5)
        >>> rng = np.random.default_rng()
        >>> y_noise = 0.2 * rng.normal(size=xdata.size)
        >>> ydata = y + y_noise
        >>> plt.plot(xdata, ydata, 'b-', label='data')

        Fit for the parameters a, b, c of the function `func`:

        >>> cf = CurveFit()
        >>> popt, pcov = cf.curve_fit(func, xdata, ydata)
        >>> popt
        array([2.56274217, 1.37268521, 0.47427475])
        >>> plt.plot(xdata, func(xdata, *popt), 'r-',
        ...          label='fit: a=%5.3f, b=%5.3f, c=%5.3f' % tuple(popt))

        Constrain the optimization to the region of ``0 <= a <= 3``,
        ``0 <= b <= 1`` and ``0 <= c <= 0.5``:

        >>> cf = CurveFit()
        >>> popt, pcov = cf.curve_fit(func, xdata, ydata, bounds=(0, [3., 1., 0.5]))
        >>> popt
        array([2.43736712, 1.        , 0.34463856])
        >>> plt.plot(xdata, func(xdata, *popt), 'g--',
        ...          label='fit: a=%5.3f, b=%5.3f, c=%5.3f' % tuple(popt))

        >>> plt.xlabel('x')
        >>> plt.ylabel('y')
        >>> plt.legend()
        >>> plt.show()
        """
        if p0 is None:
            # determine number of parameters by inspecting the function
            sig = signature(f)
            args = sig.parameters
            if len(args) < 2:
                raise ValueError("Unable to determine number of fit parameters.")
            n = len(args) - 1
        else:
            p0 = np.atleast_1d(p0)
            n = p0.size

        lb, ub = prepare_bounds(bounds, n)
        if p0 is None:
            p0 = _initialize_feasible(lb, ub)

        if method is None:
            method = 'trf'

        # NaNs cannot be handled
        to_float_array = np.asarray_chkfinite if check_finite else np.asarray
        ydata = to_float_array(ydata, float)

        # NOTE: JAX arrays are currently rejected here as well.
        if not isinstance(xdata, (list, tuple, np.ndarray)):
            raise ValueError('X needs arrays')
        xdata = to_float_array(xdata, float)

        if ydata.size == 0:
            raise ValueError("`ydata` must not be empty!")

        m = len(ydata)
        xdims = xdata.ndim
        if xdims == 1:
            xlen = len(xdata)
        else:
            xlen = len(xdata[0])
        if xlen != m:
            print(xdata.shape, ydata.shape)
            raise ValueError('X and Y data lengths dont match')

        # Remember whether the caller supplied a mask: it decides below
        # whether the padding is stripped from the returned evaluation.
        none_mask = data_mask is None
        if none_mask:
            data_mask = np.ones(m, dtype=bool)
        else:
            # Honour a caller-supplied mask whether or not a fixed length is
            # in use (it used to be silently replaced when `flength` was None).
            data_mask = np.asarray(data_mask, dtype=bool)
            if len(data_mask) != m:
                raise ValueError('Data mask doesnt match data lengths.')

        if self.flength is not None:
            len_diff = self.flength - m
            if len_diff >= 0:
                if len_diff > 0:
                    # Padding entries never take part in the fit.
                    data_mask = np.concatenate(
                        [data_mask, np.zeros(len_diff, dtype=bool)])
                xdata, ydata = self.pad_fit_data(xdata, ydata, xdims, len_diff)
            else:
                print('Data length greater than fixed length. This means retracing will occur')
        else:
            len_diff = 0

        # Number of padding entries actually appended. When the data is
        # longer than `flength` nothing is padded, so this must be 0 rather
        # than the negative `len_diff`.
        n_pad = max(len_diff, 0)

        # Determine type of sigma
        if sigma is not None:
            if not isinstance(sigma, np.ndarray):
                raise ValueError('Sigma must be numpy array.')
            # `sigma` is validated against the unpadded data length `m`.
            # if 1-D, sigma are errors, define transform = 1/sigma
            if sigma.shape == (m, ):
                if n_pad > 0:
                    sigma = np.concatenate([sigma, np.ones([n_pad])])
                transform = self.sigma_transform1d(sigma, data_mask)
            # if 2-D, sigma is the covariance matrix,
            # define transform = L such that L L^T = C
            elif sigma.shape == (m, m):
                try:
                    # Embed sigma in an identity matrix so the padding
                    # entries are uncorrelated with unit variance.
                    sigma_padded = np.identity(m + n_pad)
                    sigma_padded[:m, :m] = sigma
                    sigma = sigma_padded
                    transform = self.sigma_transform2d(sigma, data_mask)
                except Exception as err:
                    # `Exception` (not a bare except) so that
                    # KeyboardInterrupt/SystemExit still propagate; the
                    # original error is kept as the cause.
                    raise ValueError(
                        "Probably:`sigma` must be positive definite."
                    ) from err
            else:
                print('sigma shape', sigma.shape)
                print('y shape', ydata.shape, ydata.size)
                print(len_diff)
                raise ValueError("`sigma` has incorrect shape.")
        else:
            transform = None

        if 'args' in kwargs:
            # The specification for the model function `f` does not support
            # additional arguments. Refer to the `curve_fit` docstring for
            # acceptable call signatures of `f`.
            raise ValueError("'args' is not a supported keyword argument.")

        if 'max_nfev' not in kwargs:
            kwargs['max_nfev'] = kwargs.pop('maxfev', None)

        st = time.time()
        jnp_xdata = jnp.array(np.copy(xdata))
        jnp_ydata = jnp.array(np.copy(ydata))
        if timeit:
            # Wait for the arrays so `ctime` covers the whole conversion.
            jnp_xdata = jnp_xdata.block_until_ready()
            jnp_ydata = jnp_ydata.block_until_ready()
        ctime = time.time() - st

        jnp_data_mask = jnp.array(data_mask, dtype=bool)
        res = self.ls.least_squares(f, p0, jac=jac, xdata=jnp_xdata,
                                    ydata=jnp_ydata, data_mask=jnp_data_mask,
                                    transform=transform, bounds=bounds,
                                    method=method, timeit=timeit, **kwargs)

        if not res.success:
            raise RuntimeError("Optimal parameters not found: " + res.message)

        popt = res.x

        st = time.time()
        # Use the unpadded data length for the degrees of freedom.
        ysize = m
        cost = 2 * res.cost  # res.cost is half sum of squares!

        # Do Moore-Penrose inverse discarding zero singular values.
        outputs = self.covariance_svd(res.jac)
        s, VT = [np.array(output) for output in outputs]
        threshold = np.finfo(float).eps * max(res.jac.shape) * s[0]
        s = s[s > threshold]
        VT = VT[:s.size]
        pcov = np.dot(VT.T / s**2, VT)

        warn_cov = False
        if not absolute_sigma:
            if ysize > p0.size:
                s_sq = cost / (ysize - p0.size)
                pcov = pcov * s_sq
            else:
                # Not enough data points to estimate the residual variance.
                pcov.fill(np.inf)
                warn_cov = True

        if warn_cov:
            warnings.warn('Covariance of the parameters could not be estimated',
                          category=OptimizeWarning)

        post_time = time.time() - st

        if return_eval:
            feval = np.array(f(jnp_xdata, *popt))
            if none_mask:
                # Strip the padding entries added for the fixed length.
                return popt, pcov, feval[data_mask]
            return popt, pcov, feval

        # lower GPU memory usage
        res.pop('jac')
        res.pop('fun')
        if timeit:
            return popt, pcov, res, post_time, ctime
        return popt, pcov