JAXFit 2D Gaussian Example


Installing and Importing

Make sure your runtime type is set to GPU rather than CPU (in Colab: Runtime > Change runtime type). Then install JAXFit with pip.

[1]:
!pip install jaxfit
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting jaxfit
  Downloading jaxfit-0.0.4-py3-none-any.whl (38 kB)
Requirement already satisfied: scipy>=1.7.0 in /usr/local/lib/python3.8/dist-packages (from jaxfit) (1.7.3)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.8/dist-packages (from jaxfit) (3.2.2)
Requirement already satisfied: numpy in /usr/local/lib/python3.8/dist-packages (from jaxfit) (1.21.6)
Requirement already satisfied: JAX>=0.3.7 in /usr/local/lib/python3.8/dist-packages (from jaxfit) (0.3.25)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from JAX>=0.3.7->jaxfit) (4.4.0)
Requirement already satisfied: opt-einsum in /usr/local/lib/python3.8/dist-packages (from JAX>=0.3.7->jaxfit) (3.3.0)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib->jaxfit) (0.11.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->jaxfit) (1.4.4)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->jaxfit) (3.0.9)
Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib->jaxfit) (2.8.2)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.1->matplotlib->jaxfit) (1.15.0)
Installing collected packages: jaxfit
Successfully installed jaxfit-0.0.4
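To confirm that JAX can actually see the GPU runtime requested above, you can ask JAX for its default backend and device list. This is a minimal check using standard JAX calls; it is not part of JAXFit or the original notebook.

```python
import jax

# Platform JAX will use by default, e.g. "gpu", "tpu" or "cpu".
backend = jax.default_backend()
print("Default backend:", backend)

# All devices visible to JAX; on a GPU runtime this lists the GPU.
print("Devices:", jax.devices())
```

If this prints "cpu", the fits below will still run, but without GPU acceleration.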

Import JAXFit before importing JAX, because importing JAXFit configures JAX to use 64-bit rather than 32-bit arrays for all computations.

[2]:
from jaxfit import CurveFit
import jax.numpy as jnp
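To confirm that 64-bit mode is active after the imports, check the default dtype of a new array. If it is still float32 (for example, because JAX was imported first), JAX's standard jax_enable_x64 configuration flag turns 64-bit mode on. A minimal sketch using only JAX's public config API, not JAXFit's internals:

```python
import jax
import jax.numpy as jnp

# With 64-bit mode off, JAX creates float32 arrays by default.
if jnp.ones(1).dtype != jnp.float64:
    # Standard JAX switch; ideally set once at startup.
    jax.config.update("jax_enable_x64", True)

print(jnp.ones(1).dtype)  # float64
```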

Now let’s define a 2D Gaussian using jax.numpy. You can write functions just as you would with NumPy, with a few small caveats.
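One such caveat is that JAX arrays are immutable, so NumPy-style item assignment is not allowed. A minimal sketch of the failure and of JAX's functional .at[...] update API, which returns a modified copy instead:

```python
import jax.numpy as jnp

x = jnp.zeros(3)

# NumPy-style in-place assignment raises a TypeError in JAX.
try:
    x[0] = 1.0
    raised = False
except TypeError:
    raised = True
print("in-place assignment raised TypeError:", raised)

# The functional update returns a new array and leaves x untouched.
y = x.at[0].set(1.0)
print(x)  # still all zeros
print(y)  # first element is now 1.0
```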

[3]:
def rotate_coordinates2D(coords, theta):
    # 2D rotation matrix for an angle theta (in radians)
    R = jnp.array([[jnp.cos(theta), -jnp.sin(theta)],
                   [jnp.sin(theta), jnp.cos(theta)]])

    # Flatten each coordinate array, stack them into shape (2, N),
    # rotate, then restore the original array shape
    shape = coords[0].shape
    coords = jnp.stack([coord.flatten() for coord in coords])
    rcoords = R @ coords
    return [jnp.reshape(coord, shape) for coord in rcoords]


def gaussian2d(coords, n0, x0, y0, sigma_x, sigma_y, theta, offset):
    # Translate first so the Gaussian is centered at (x0, y0), then rotate
    coords = [coords[0] - x0, coords[1] - y0]
    X, Y = rotate_coordinates2D(coords, theta)
    density = n0 * jnp.exp(-.5 * (X**2 / sigma_x**2 + Y**2 / sigma_y**2))
    return density + offset
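For reference, gaussian2d shifts the grid by (x0, y0), rotates it with the matrix R built in rotate_coordinates2D, and evaluates an axis-aligned Gaussian in the rotated frame. Written out:

```latex
\begin{aligned}
X &= (x - x_0)\cos\theta - (y - y_0)\sin\theta, \\
Y &= (x - x_0)\sin\theta + (y - y_0)\cos\theta, \\
f(x, y) &= n_0 \exp\!\left[-\tfrac{1}{2}\left(\frac{X^2}{\sigma_x^2} + \frac{Y^2}{\sigma_y^2}\right)\right] + \text{offset}.
\end{aligned}
```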

Using this function, we simulate a noisy synthetic image to fit and plot it.

[4]:
import numpy as np
import matplotlib.pyplot as plt
import time

def get_coordinates(width, height):
    x = np.linspace(0, width - 1, width)
    y = np.linspace(0, height - 1, height)
    X, Y = np.meshgrid(x, y)
    return X, Y


def get_gaussian_parameters(length):
    # Ground-truth parameters, scaled to the image size
    n0 = 1
    x0 = length / 2
    y0 = length / 2
    sigx = length / 6
    sigy = length / 8
    theta = np.pi / 3

    offset = .1 * n0
    params = [n0, x0, y0, sigx, sigy, theta, offset]
    return params

length = 500
XY_tuple = get_coordinates(length, length)

params = get_gaussian_parameters(length)
zdata = gaussian2d(XY_tuple, *params)
# Add Gaussian noise with standard deviation 0.1
zdata += np.random.normal(0, .1, size=(length, length))

plt.imshow(zdata)
plt.show()
[Figure: the simulated noisy 2D Gaussian image, zdata.]
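The notebook draws its noise from NumPy's global random state, so each run produces a slightly different image. If you want reproducible data, one option is a seeded np.random.default_rng generator. In the sketch below, gaussian2d_np and simulate_image are hypothetical helpers (not part of JAXFit) that mirror gaussian2d in plain NumPy so the block runs on its own:

```python
import numpy as np

def gaussian2d_np(coords, n0, x0, y0, sigma_x, sigma_y, theta, offset):
    """Hypothetical NumPy mirror of the notebook's gaussian2d."""
    dx, dy = coords[0] - x0, coords[1] - y0
    X = np.cos(theta) * dx - np.sin(theta) * dy
    Y = np.sin(theta) * dx + np.cos(theta) * dy
    return n0 * np.exp(-0.5 * (X**2 / sigma_x**2 + Y**2 / sigma_y**2)) + offset

def simulate_image(length, params, noise_std=0.1, seed=0):
    """Noisy length x length image; equal seeds give identical noise."""
    rng = np.random.default_rng(seed)
    x = np.arange(length, dtype=float)  # same grid as np.linspace(0, length - 1, length)
    X, Y = np.meshgrid(x, x)
    clean = gaussian2d_np((X, Y), *params)
    return clean + rng.normal(0.0, noise_std, size=clean.shape)

demo_length = 100
demo_params = [1, demo_length / 2, demo_length / 2, demo_length / 6,
               demo_length / 8, np.pi / 3, 0.1]
img_a = simulate_image(demo_length, demo_params, seed=42)
img_b = simulate_image(demo_length, demo_params, seed=42)
print(np.array_equal(img_a, img_b))  # True: same seed, same noise
```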

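Now we initialize the JAXFit CurveFit object and fit the synthetic data 100 times, each time starting from a different randomly perturbed initial guess (p0) rather than the true parameters. For comparison, we then run a single SciPy curve_fit from the last of those initial guesses.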

[8]:
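from scipy.optimize import curve_fit

def get_random_float(low, high):
    # Uniform random float in [low, high)
    delta = high - low
    return low + delta * np.random.random()

# Fit on 1D data: flatten the image and the coordinate grids
flat_data = zdata.flatten()
flat_XY_tuple = [coord.flatten() for coord in XY_tuple]
jcf = CurveFit()

loop = 100
times = []   # JAXFit fit times
stimes = []  # SciPy fit time
for i in range(loop):
    # Initial guess: each true parameter scaled by a random factor in [0.9, 1.2)
    seed = [val * get_random_float(.9, 1.2) for val in params]
    st = time.time()
    popt, pcov = jcf.curve_fit(gaussian2d, flat_XY_tuple, flat_data, p0=seed)
    times.append(time.time() - st)

# One SciPy fit from the last initial guess, for comparison
st = time.time()
popt2, pcov2 = curve_fit(gaussian2d, flat_XY_tuple, flat_data, p0=seed)
stimes.append(time.time() - st)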
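Each perturbed guess multiplies every parameter by its own factor drawn uniformly from [0.9, 1.2) via get_random_float. The same perturbation can be written in vectorized form with NumPy's Generator API; perturbed_guess below is a hypothetical helper, not part of JAXFit or the notebook:

```python
import numpy as np

def perturbed_guess(true_params, low=0.9, high=1.2, rng=None):
    """Scale each parameter by an independent factor drawn from U[low, high)."""
    rng = np.random.default_rng() if rng is None else rng
    true_params = np.asarray(true_params, dtype=float)
    return true_params * rng.uniform(low, high, size=true_params.shape)

rng = np.random.default_rng(0)
true_params = [1.0, 250.0, 250.0, 500 / 6, 500 / 8, np.pi / 3, 0.1]
p0 = perturbed_guess(true_params, rng=rng)
print(p0 / np.asarray(true_params))  # each ratio lies in [0.9, 1.2)
```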

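Let’s see how fast these fits ran and how their parameters compare with SciPy’s. The first fit is left out of the timing statistics because it also includes JAX's one-time JIT compilation.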
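To see first-call overhead in isolation, here is a generic JAX sketch (not JAXFit code; sum_of_gaussian is a made-up stand-in workload) that times a jitted function twice. The first call traces and compiles the function for the given input shape; the second reuses the compiled version.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def sum_of_gaussian(x):
    # Arbitrary stand-in workload; any jitted function behaves similarly.
    return jnp.sum(jnp.exp(-0.5 * x**2))

x = jnp.linspace(-5.0, 5.0, 250_000)

def timed_call(fn, arg):
    start = time.perf_counter()
    out = fn(arg).block_until_ready()  # wait for JAX's async dispatch
    return out, time.perf_counter() - start

out, first = timed_call(sum_of_gaussian, x)   # includes compilation
out, second = timed_call(sum_of_gaussian, x)  # compiled code only
print(f"first call: {first:.4f} s, second call: {second:.4f} s")
```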

[9]:
print('Average fit time', np.mean(times[1:]))
print('JAXFit parameters', popt)
print('SciPy parameters', popt2)

plt.figure()
plt.plot(times[1:])
plt.xlabel('Fit Number')
plt.ylabel('Fit Time (seconds)')
plt.show()
Average fit time 0.033657815721299916
JAXFit parameters [1.00018438e+00 2.50022831e+02 2.49971078e+02 8.32246867e+01
 6.25156007e+01 1.05193543e+00 9.98519546e-02]
SciPy parameters [1.00018438e+00 2.50022831e+02 2.49971078e+02 8.32246867e+01
 6.25156007e+01 1.05193543e+00 9.98519546e-02]
[Figure: time per JAXFit fit, in seconds, versus fit number, excluding the first fit.]

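Very speedy: after the first call, each fit of the 250,000-point (500 x 500) image takes about 0.034 s on average, and the JAXFit parameters match SciPy's to all printed digits.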
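To make the accuracy comparison quantitative, the two printed parameter vectors can be checked against each other and against the values used to simulate the data. This sketch copies the printed numbers, so it only reflects their displayed precision:

```python
import numpy as np

side = 500  # image side length used in the notebook
true_params = np.array([1.0, side / 2, side / 2, side / 6, side / 8, np.pi / 3, 0.1])

# Values copied from the printed output above
jaxfit_popt = np.array([1.00018438e+00, 2.50022831e+02, 2.49971078e+02,
                        8.32246867e+01, 6.25156007e+01, 1.05193543e+00,
                        9.98519546e-02])
scipy_popt = np.array([1.00018438e+00, 2.50022831e+02, 2.49971078e+02,
                       8.32246867e+01, 6.25156007e+01, 1.05193543e+00,
                       9.98519546e-02])

# Agreement between the two fitters
print(np.allclose(jaxfit_popt, scipy_popt))

# Relative error of each fitted parameter versus the ground truth
rel_err = np.abs(jaxfit_popt - true_params) / np.abs(true_params)
print(rel_err.max())
```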

[7]:
# jcf = CurveFit(flength=length**2)