"""Generic interface for least-squares minimization."""
from warnings import warn
import numpy as np
import time
from typing import Callable, Optional, Tuple, Union, Sequence, List, Any
from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jit, jacfwd
from jax.scipy.linalg import solve_triangular as jax_solve_triangular
from jaxfit.trf import TrustRegionReflective
from jaxfit.loss_functions import LossFunctionsJIT
from jaxfit.common_scipy import EPS, in_bounds, make_strictly_feasible
# Human-readable text for each integer termination status.
# `LeastSquares.least_squares` looks up `result.status` in this table to fill
# `result.message`, and treats any status greater than 0 as success.
# NOTE(review): the -1 entry mentions `leastsq`; presumably inherited from
# SciPy's MINPACK wrapper -- confirm whether the 'trf' solver can return it.
TERMINATION_MESSAGES = {
-1: "Improper input parameters status returned from `leastsq`",
0: "The maximum number of function evaluations is exceeded.",
1: "`gtol` termination condition is satisfied.",
2: "`ftol` termination condition is satisfied.",
3: "`xtol` termination condition is satisfied.",
4: "Both `ftol` and `xtol` termination conditions are satisfied."
}
def prepare_bounds(bounds, n) -> Tuple[np.ndarray, np.ndarray]:
    """Convert a ``(lower, upper)`` pair into two float arrays.

    Each bound is converted to a float array. A bound that is a scalar
    (0-d after conversion) is expanded to a 1-D array of length `n`; bounds
    that already have one or more dimensions are returned as converted,
    without any length check.

    Parameters
    ----------
    bounds : Tuple[np.ndarray, np.ndarray]
        The lower and upper bounds, each a scalar or array_like.
    n : int
        Length that scalar bounds are expanded to.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        The lower and upper bounds as float arrays.
    """
    converted = [np.asarray(bound, dtype=float) for bound in bounds]
    lower, upper = converted
    # only 0-d (scalar) bounds are broadcast; array bounds pass through
    lower = np.resize(lower, n) if lower.ndim == 0 else lower
    upper = np.resize(upper, n) if upper.ndim == 0 else upper
    return lower, upper
def check_tolerance(ftol: float,
                    xtol: float,
                    gtol: float,
                    method: str
                    ) -> Tuple[float, float, float]:
    """Validate the three termination tolerances.

    A tolerance given as `None` is replaced by 0. A tolerance below the
    machine epsilon `EPS` triggers a warning (that termination condition is
    effectively disabled) but the value itself is returned unchanged. If all
    three tolerances end up below `EPS`, a `ValueError` is raised.

    Parameters
    ----------
    ftol : float
        The tolerance for the optimization function value.
    xtol : float
        The tolerance for the optimization variable values.
    gtol : float
        The tolerance for the optimization gradient values.
    method : str
        The name of the optimization method. Not used by this function.

    Returns
    -------
    Tuple[float, float, float]
        ``(ftol, xtol, gtol)`` with `None` entries replaced by 0.

    Raises
    ------
    ValueError
        If every tolerance is below the machine epsilon.
    """
    def resolve(value: float, label: str) -> float:
        # None means "condition disabled" and is represented by 0
        if value is None:
            return 0
        if value < EPS:
            warn("Setting `{}` below the machine epsilon ({:.2e}) effectively "
                 "disables the corresponding termination condition."
                 .format(label, EPS))
        return value

    ftol = resolve(ftol, "ftol")
    xtol = resolve(xtol, "xtol")
    gtol = resolve(gtol, "gtol")

    # at least one termination condition must remain usable
    if all(tol < EPS for tol in (ftol, xtol, gtol)):
        raise ValueError("At least one of the tolerances must be higher than "
                         "machine epsilon ({:.2e}).".format(EPS))

    return ftol, xtol, gtol
def check_x_scale(x_scale: Union[str, Sequence[float]],
                  x0: Sequence[float]
                  ) -> Union[str, np.ndarray]:
    """Check and prepare the `x_scale` parameter for optimization.

    `x_scale` is either the string 'jac', which is returned untouched, or an
    array_like of finite, strictly positive numbers. A scalar `x_scale` is
    expanded to the shape of `x0`.

    Parameters
    ----------
    x_scale : Union[str, Sequence[float]]
        The scaling for the optimization variables.
    x0 : Sequence[float]
        The initial guess for the optimization variables. Any array_like is
        accepted; only its shape is used.

    Returns
    -------
    Union[str, np.ndarray]
        'jac', or `x_scale` as a float array with the same shape as `x0`.

    Raises
    ------
    ValueError
        If `x_scale` is not 'jac' and is not made of finite positive numbers,
        or if its shape differs from the shape of `x0`.
    """
    if isinstance(x_scale, str) and x_scale == 'jac':
        return x_scale

    try:
        x_scale = np.asarray(x_scale, dtype=float)
        valid = np.all(np.isfinite(x_scale)) and np.all(x_scale > 0)
    except (ValueError, TypeError):
        # non-numeric input (e.g. an unknown string) cannot be converted
        valid = False

    if not valid:
        raise ValueError("`x_scale` must be 'jac' or array_like with "
                         "positive numbers.")

    # np.shape works for lists/tuples as well as arrays, matching the
    # Sequence[float] annotation (a bare `x0.shape` fails on plain sequences)
    x0_shape = np.shape(x0)

    if x_scale.ndim == 0:
        x_scale = np.resize(x_scale, x0_shape)

    if x_scale.shape != x0_shape:
        raise ValueError("Inconsistent shapes between `x_scale` and `x0`.")

    return x_scale
"""Wraps the given function so that a masked `jacfwd` is performed on it,
thereby giving the autodiff Jacobian."""
class AutoDiffJacobian():
    """Builds forward-mode autodiff Jacobians of a residual function.

    The most recently built Jacobian is kept on the instance as ``self.jac``,
    so separate instances can each hold their own compiled Jacobian at the
    same time.
    """

    def create_ad_jacobian(self,
                           func: Callable,
                           num_args: int,
                           masked: bool = True
                           ) -> Callable:
        """Create a JIT-compiled Jacobian of `func` using `jacfwd`.

        Parameters
        ----------
        func : Callable
            The function to differentiate, called as
            ``func(args, xdata, ydata, data_mask, atransform)``.
        num_args : int
            The number of fit parameters in `args`; each one is
            differentiated with respect to.
        masked : bool, optional
            Whether entries where `data_mask` is false are replaced by 0,
            by default True.

        Returns
        -------
        Callable
            A function ``jac(args, xdata, ydata, data_mask, atransform)``.
            The per-parameter derivatives are stacked along the first axis
            and the result is transposed and made at least 2-D. The same
            function is also stored as ``self.jac``.
        """
        # The flattened function receives the four data arguments first, so
        # the fit parameters occupy positions 4, 5, ...; only those positions
        # are passed to jacfwd as the arguments to differentiate.
        param_positions = list(range(4, 4 + num_args))

        @jit
        def flat_residual(*flat_args: List[Any]) -> jnp.ndarray:
            """Adapter giving `func` a flat positional signature, so that
            every fit parameter is its own positional argument for jacfwd."""
            data_part = flat_args[:4]
            params = flat_args[4:]
            xdata, ydata, data_mask, atransform = data_part
            return func(params, xdata, ydata, data_mask, atransform)

        @jit
        def stacked_jac(args: List[float],
                        xdata: jnp.ndarray,
                        ydata: jnp.ndarray,
                        data_mask: jnp.ndarray,
                        atransform: jnp.ndarray
                        ) -> jnp.ndarray:
            """Differentiate the flattened residual with respect to the fit
            parameters only and stack the per-parameter results."""
            differentiate = jacfwd(flat_residual, argnums=param_positions)
            per_param = differentiate(xdata, ydata, data_mask, atransform,
                                      *args)
            return jnp.array(per_param)

        @jit
        def masked_jac(args: List[float],
                       xdata: jnp.ndarray,
                       ydata: jnp.ndarray,
                       data_mask: jnp.ndarray,
                       atransform: jnp.ndarray
                       ) -> jnp.ndarray:
            """Jacobian with entries zeroed wherever `data_mask` is false."""
            stacked = stacked_jac(args, xdata, ydata, data_mask, atransform)
            zeroed = jnp.where(data_mask, stacked, 0)
            return jnp.atleast_2d(zeroed.T)

        @jit
        def no_mask_jac(args: List[float],
                        xdata: jnp.ndarray,
                        ydata: jnp.ndarray,
                        data_mask: jnp.ndarray,
                        atransform: jnp.ndarray
                        ) -> jnp.ndarray:
            """Jacobian without any masking applied."""
            stacked = stacked_jac(args, xdata, ydata, data_mask, atransform)
            return jnp.atleast_2d(stacked.T)

        self.jac = masked_jac if masked else no_mask_jac
        return self.jac
class LeastSquares():
    """Front end for bounded nonlinear least-squares minimization.

    Validates the inputs, builds JIT-compiled residual and Jacobian functions
    (cached on the instance so that repeated fits of the same function do not
    have to rebuild them) and hands the problem to
    ``TrustRegionReflective.trf``.
    """

    def __init__(self):
        super().__init__()
        self.trf = TrustRegionReflective()
        self.ls = LossFunctionsJIT()

        # placeholders: a dummy fit function and no user Jacobian, so that
        # the first real fit function is detected as new
        self.f = lambda x: None
        self.jac = None

        # one autodiff wrapper per uncertainty-transform case
        # (no transform, 1D transform, 2D transform)
        self.adjn = AutoDiffJacobian()
        self.adj1d = AutoDiffJacobian()
        self.adj2d = AutoDiffJacobian()

    @staticmethod
    def _same_callable(cached: Any, candidate: Any) -> bool:
        """Return True if `candidate` can reuse what was built for `cached`.

        Two distinct function objects are treated as the same function when
        their bytecode, constants and referenced names all match. Comparing
        bytecode alone is not enough: ``a * sin(x)`` and ``a * cos(x)``, or
        ``x + 1.5`` and ``x + 2.5``, compile to identical bytecode and differ
        only in the names/constants tables.

        Callables without a ``__code__`` attribute (partials, callable
        objects, builtins) only match themselves. Values captured in
        closures are not compared.
        """
        if cached is candidate:
            return True
        cached_code = getattr(cached, "__code__", None)
        candidate_code = getattr(candidate, "__code__", None)
        if cached_code is None or candidate_code is None:
            return False
        return (cached_code.co_code == candidate_code.co_code
                and cached_code.co_consts == candidate_code.co_consts
                and cached_code.co_names == candidate_code.co_names)

    def least_squares(self,
                      fun: Callable,
                      x0: np.ndarray,
                      jac: Optional[Callable] = None,
                      bounds: Tuple[np.ndarray, np.ndarray] = (-np.inf, np.inf),
                      method: str = 'trf',
                      ftol: float = 1e-8,
                      xtol: float = 1e-8,
                      gtol: float = 1e-8,
                      x_scale: Union[str, np.ndarray, float] = 1.0,
                      loss: str = 'linear',
                      f_scale: float = 1.0,
                      diff_step=None,
                      tr_solver=None,
                      tr_options={},
                      jac_sparsity=None,
                      max_nfev: Optional[float] = None,
                      verbose: int = 0,
                      xdata: Optional[jnp.ndarray] = None,
                      ydata: Optional[jnp.ndarray] = None,
                      data_mask: Optional[jnp.ndarray] = None,
                      transform: Optional[jnp.ndarray] = None,
                      timeit: bool = False,
                      args=(),
                      kwargs={}):
        """Solve a bounded nonlinear least-squares problem.

        Parameters
        ----------
        fun : Callable
            When both `xdata` and `ydata` are given: the fit function, called
            as ``fun(xdata, *params)``. Otherwise: a residual function called
            as ``fun(params, *args, **kwargs)``.
        x0 : np.ndarray
            Initial guess; real, at most 1-D, and within `bounds`.
        jac : Optional[Callable]
            User Jacobian, or None to use automatic differentiation.
        bounds : Tuple[np.ndarray, np.ndarray]
            Lower and upper bounds (scalars or arrays shaped like `x0`).
        method : str
            Must be 'trf'.
        ftol, xtol, gtol : float
            Termination tolerances, validated by `check_tolerance`.
        x_scale : Union[str, np.ndarray, float]
            Variable scaling, validated by `check_x_scale`.
        loss : str
            One of ``self.ls.IMPLEMENTED_LOSSES`` or a callable.
        f_scale : float
            Passed to the loss function and to the solver.
        diff_step, tr_solver, jac_sparsity
            Accepted but not used by this method.
        tr_options : dict
            A copy is forwarded to the solver.
        max_nfev : Optional[float]
            None or a positive number; forwarded to the solver.
        verbose : int
            0, 1 or 2. At 1 or above a summary is printed.
        xdata, ydata : Optional[jnp.ndarray]
            Fit data. If either is None, `fun`/`jac` are used as plain
            residual/Jacobian functions of the parameters.
        data_mask : Optional[jnp.ndarray]
            Mask over the data; defaults to all ones.
        transform : Optional[jnp.ndarray]
            Uncertainty transform; None, 1-D or higher-dimensional selects
            the corresponding residual/Jacobian variant.
        timeit : bool
            Forwarded to the solver.
        args, kwargs
            Extra arguments for `fun`/`jac`; only used when `xdata` or
            `ydata` is None.

        Returns
        -------
        The result object returned by ``self.trf.trf``, with ``message`` and
        ``success`` filled in from its ``status``.

        Raises
        ------
        ValueError
            On invalid arguments, an infeasible `x0`, non-finite initial
            residuals, or wrongly shaped residual/Jacobian/loss outputs.
        """
        if data_mask is None and ydata is not None:
            data_mask = jnp.ones(len(ydata), dtype=bool)

        if loss not in self.ls.IMPLEMENTED_LOSSES and not callable(loss):
            raise ValueError("`loss` must be one of {0} or a callable."
                             .format(self.ls.IMPLEMENTED_LOSSES.keys()))

        if method not in ['trf']:
            raise ValueError("`method` must be 'trf'")

        if jac is not None and not callable(jac):
            raise ValueError("`jac` must be None or "
                             "callable.")

        if verbose not in [0, 1, 2]:
            raise ValueError("`verbose` must be in [0, 1, 2].")

        if len(bounds) != 2:
            raise ValueError("`bounds` must contain 2 elements.")

        if max_nfev is not None and max_nfev <= 0:
            raise ValueError("`max_nfev` must be None or positive integer.")

        if np.iscomplexobj(x0):
            raise ValueError("`x0` must be real.")

        x0 = np.atleast_1d(x0).astype(float)

        if x0.ndim > 1:
            raise ValueError("`x0` must have at most 1 dimension.")

        self.n = len(x0)

        lb, ub = prepare_bounds(bounds, x0.shape[0])

        if lb.shape != x0.shape or ub.shape != x0.shape:
            raise ValueError("Inconsistent shapes between bounds and `x0`.")

        if np.any(lb >= ub):
            raise ValueError("Each lower bound must be strictly less than each "
                             "upper bound.")

        if not in_bounds(x0, lb, ub):
            raise ValueError("`x0` is infeasible.")

        x_scale = check_x_scale(x_scale, x0)
        ftol, xtol, gtol = check_tolerance(ftol, xtol, gtol, method)
        x0 = make_strictly_feasible(x0, lb, ub)

        if xdata is not None and ydata is not None:
            # A new function object is not necessarily a new function, so
            # the code (bytecode, constants, names) is compared instead of
            # relying on identity alone.
            func_update = not self._same_callable(self.f, fun)

            # a new fit function needs new residual wrappers, and, when
            # autodiff is in use (jac=None), new autodiff Jacobians too
            if func_update:
                self.update_function(fun)
                if jac is None:
                    self.autdiff_jac(jac)

            if jac is not None:
                # analytical Jacobian: wrap on first use or when it changed
                if self.jac is None or not self._same_callable(self.jac, jac):
                    self.wrap_jac(jac)
            elif self.jac is not None and not func_update:
                # switching from an analytical Jacobian back to autodiff
                # while keeping the same fit function
                self.autdiff_jac(jac)

            # pick the residual/Jacobian pair matching the uncertainty
            # transform: none, 1D, or 2D
            if transform is None:
                rfunc = self.func_none
                jac_func = self.jac_none
            elif transform.ndim == 1:
                rfunc = self.func_1d
                jac_func = self.jac_1d
            else:
                rfunc = self.func_2d
                jac_func = self.jac_2d
        else:
            # Compatibility path for the SciPy test-suite, where the residual
            # function already contains the fit data: the data arguments are
            # accepted but ignored.
            def wrap_func(fargs, xdata, ydata, data_mask, atransform):
                return jnp.atleast_1d(fun(fargs, *args, **kwargs))

            def wrap_jac(fargs, xdata, ydata, data_mask, atransform):
                return jnp.atleast_2d(jac(fargs, *args, **kwargs))

            rfunc = wrap_func
            if jac is None:
                adj = AutoDiffJacobian()
                jac_func = adj.create_ad_jacobian(wrap_func, self.n,
                                                  masked=False)
            else:
                jac_func = wrap_jac

        f0 = rfunc(x0, xdata, ydata, data_mask, transform)
        J0 = jac_func(x0, xdata, ydata, data_mask, transform)

        if f0.ndim != 1:
            raise ValueError("`fun` must return at most 1-d array_like. "
                             "f0.shape: {0}".format(f0.shape))

        if not np.all(np.isfinite(f0)):
            raise ValueError("Residuals are not finite in the initial point.")

        n = x0.size
        m = f0.size

        if J0 is not None:
            if J0.shape != (m, n):
                raise ValueError(
                    "The return value of `jac` has wrong shape: expected {0}, "
                    "actual {1}.".format((m, n), J0.shape))

        if data_mask is None:
            data_mask = jnp.ones(m)

        loss_function = self.ls.get_loss_function(loss)

        if callable(loss):
            rho = loss_function(f0, f_scale, data_mask=data_mask)
            if rho.shape != (3, m):
                raise ValueError("The return value of `loss` callable has wrong "
                                 "shape.")
            initial_cost_jnp = self.trf.calculate_cost(rho, data_mask)
        elif loss_function is not None:
            initial_cost_jnp = loss_function(f0, f_scale, data_mask=data_mask,
                                             cost_only=True)
        else:
            initial_cost_jnp = self.trf.default_loss_func(f0)
        initial_cost = np.array(initial_cost_jnp)

        result = self.trf.trf(rfunc, xdata, ydata, jac_func, data_mask,
                              transform, x0, f0, J0, lb, ub, ftol, xtol,
                              gtol, max_nfev, f_scale, x_scale, loss_function,
                              tr_options.copy(), verbose, timeit)

        result.message = TERMINATION_MESSAGES[result.status]
        result.success = result.status > 0

        if verbose >= 1:
            print(result.message)
            print("Function evaluations {0}, initial cost {1:.4e}, final cost "
                  "{2:.4e}, first-order optimality {3:.2e}."
                  .format(result.nfev, initial_cost, result.cost,
                          result.optimality))

        return result

    def autdiff_jac(self, jac: None) -> None:
        """Build autodiff Jacobians for all three residual variants.

        All three (no transform, 1D transform, 2D transform) are built up
        front so that switching between uncertainty transforms later does
        not require rebuilding them.

        Parameters
        ----------
        jac : None
            Stored as ``self.jac``. Passed in to keep the call symmetrical
            with the user-defined Jacobian path.
        """
        self.jac_none = self.adjn.create_ad_jacobian(self.func_none, self.n)
        self.jac_1d = self.adj1d.create_ad_jacobian(self.func_1d, self.n)
        self.jac_2d = self.adj2d.create_ad_jacobian(self.func_2d, self.n)
        # remember that no user Jacobian is currently wrapped
        self.jac = jac

    def update_function(self, func: Callable) -> None:
        """Wrap a fit function into JIT-compiled residual functions.

        Three residual variants are stored on the instance (`func_none`,
        `func_1d`, `func_2d`), all with the signature
        ``(args, xdata, ydata, data_mask, atransform)``. Everything the
        residual needs is passed in explicitly to keep the functions purely
        functional for JAX. `func` itself is remembered as ``self.f``.

        Parameters
        ----------
        func : Callable
            The fit function, called as ``func(xdata, *args)``.

        Returns
        -------
        None
        """
        @jit
        def masked_residual_func(args: List[float],
                                 xdata: jnp.ndarray,
                                 ydata: jnp.ndarray,
                                 data_mask: jnp.ndarray
                                 ) -> jnp.ndarray:
            """Residual ``func(xdata, *args) - ydata`` with entries set to 0
            wherever `data_mask` is false.

            Parameters
            ----------
            args : List[float]
                The parameters of the function.
            xdata : jnp.ndarray
                The independent variable data.
            ydata : jnp.ndarray
                The dependent variable data.
            data_mask : jnp.ndarray
                The mask for the data.

            Returns
            -------
            jnp.ndarray
                The masked residual.
            """
            func_eval = func(xdata, *args) - ydata
            return jnp.where(data_mask, func_eval, 0)

        # One residual function per uncertainty-transform case, since the
        # transform is applied differently in each.
        @jit
        def func_no_transform(args: List[float],
                              xdata: jnp.ndarray,
                              ydata: jnp.ndarray,
                              data_mask: jnp.ndarray,
                              atransform: jnp.ndarray
                              ) -> jnp.ndarray:
            """Residual without an uncertainty transform. `atransform` is
            unused and only kept for a consistent signature."""
            return masked_residual_func(args, xdata, ydata, data_mask)

        @jit
        def func_1d_transform(args: List[float],
                              xdata: jnp.ndarray,
                              ydata: jnp.ndarray,
                              data_mask: jnp.ndarray,
                              atransform: jnp.ndarray
                              ) -> jnp.ndarray:
            """Residual multiplied elementwise by the 1D uncertainty
            transform `atransform`."""
            return atransform * masked_residual_func(args, xdata,
                                                     ydata, data_mask)

        @jit
        def func_2d_transform(args: List[float],
                              xdata: jnp.ndarray,
                              ydata: jnp.ndarray,
                              data_mask: jnp.ndarray,
                              atransform: jnp.ndarray
                              ) -> jnp.ndarray:
            """Residual for a 2D uncertainty transform: solves the
            lower-triangular system ``atransform @ r = f``."""
            f = masked_residual_func(args, xdata, ydata, data_mask)
            return jax_solve_triangular(atransform, f, lower=True)

        self.func_none = func_no_transform
        self.func_1d = func_1d_transform
        self.func_2d = func_2d_transform
        self.f = func

    def wrap_jac(self, jac: Callable) -> None:
        """Wrap a user-defined Jacobian to support masking and transforms.

        Three variants are stored on the instance (`jac_none`, `jac_1d`,
        `jac_2d`), all with the signature
        ``(args, coords, ydata, data_mask, atransform)``. `ydata` is not
        used by any of them; it is accepted so the signature matches the
        autodiff Jacobians. `jac` itself is remembered as ``self.jac``.

        Parameters
        ----------
        jac : Callable
            The Jacobian function, called as ``jac(coords, *args)``. Its
            output is masked with `data_mask` and then transposed.

        Returns
        -------
        None
        """
        @jit
        def jac_func(coords: jnp.ndarray,
                     args: List[float]
                     ) -> jnp.ndarray:
            """Evaluate the user Jacobian and convert it to a JAX array."""
            jac_fwd = jac(coords, *args)
            return jnp.array(jac_fwd)

        @jit
        def masked_jac(coords: jnp.ndarray,
                       args: List[float],
                       data_mask: jnp.ndarray
                       ) -> jnp.ndarray:
            """User Jacobian with entries zeroed wherever `data_mask` is
            false, transposed."""
            Jt = jac_func(coords, args)
            return jnp.where(data_mask, Jt, 0).T

        @jit
        def jac_no_transform(args: List[float],
                             coords: jnp.ndarray,
                             ydata: jnp.ndarray,
                             data_mask: jnp.ndarray,
                             atransform: jnp.ndarray
                             ) -> jnp.ndarray:
            """Masked Jacobian without an uncertainty transform."""
            return jnp.atleast_2d(masked_jac(coords, args, data_mask))

        @jit
        def jac_1d_transform(args: List[float],
                             coords: jnp.ndarray,
                             ydata: jnp.ndarray,
                             data_mask: jnp.ndarray,
                             atransform: jnp.ndarray
                             ) -> jnp.ndarray:
            """Masked Jacobian with each row scaled by the matching entry of
            the 1D uncertainty transform."""
            J = masked_jac(coords, args, data_mask)
            return jnp.atleast_2d(atransform[:, jnp.newaxis] * jnp.asarray(J))

        @jit
        def jac_2d_transform(args: List[float],
                             coords: jnp.ndarray,
                             ydata: jnp.ndarray,
                             data_mask: jnp.ndarray,
                             atransform: jnp.ndarray
                             ) -> jnp.ndarray:
            """Masked Jacobian for a 2D uncertainty transform: solves the
            lower-triangular system ``atransform @ X = J``."""
            J = masked_jac(coords, args, data_mask)
            return jnp.atleast_2d(jax_solve_triangular(atransform,
                                                       jnp.asarray(J),
                                                       lower=True))

        # all three variants are kept so the transform can change between
        # calls without rebuilding the wrappers
        self.jac_none = jac_no_transform
        self.jac_1d = jac_1d_transform
        self.jac_2d = jac_2d_transform
        self.jac = jac