Source code for lightcurvelynx.utils.bicubic_interp

"""The BicubicInterpolator is used by SALT models.

It is adapted from sncosmo's BicubicInterpolator class (but implemented in JAX):
https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/salt2utils.pyx
"""

try:
    import jax.numpy as jnp
    from jax import jit, vmap
except ImportError as err:  # pragma: no cover
    raise ImportError(
        "JAX is required to use the BicubicInterpolator class, please "
        "install with `pip install jax` or `conda install conda-forge::jax`"
    ) from err

from lightcurvelynx.utils.io_utils import read_grid_data


@jit
def _kernel_value(input_vals):
    """Evaluate the cubic convolution kernel elementwise.

    This is a vectorized (JAX) form of the ``kernval`` function from:
    https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/salt2utils.pyx

    With A = -0.5, B = A + 2.0 = 1.5 and C = A + 3.0 = 2.5 the kernel is::

        W(x) = B|x|^3 - C|x|^2 + 1              for |x| < 1
        W(x) = A(|x|^3 - 5|x|^2 + 8|x| - 4)     for 1 <= |x| <= 2
        W(x) = 0                                for |x| > 2

    Parameters
    ----------
    input_vals : JAX array
        The (signed) offsets at which to evaluate the kernel.

    Returns
    -------
    JAX array
        The kernel weight for each input offset.
    """
    dist = jnp.abs(input_vals)

    # Both polynomial pieces are evaluated everywhere; the masks below select one.
    near_piece = dist * dist * (1.5 * dist - 2.5) + 1.0
    far_piece = -0.5 * (-4.0 + dist * (8.0 + dist * (-5.0 + dist)))

    return jnp.where(dist > 2.0, 0.0, jnp.where(dist < 1.0, near_piece, far_piece))


def _compute_linear(ix, iy, ax, ay, z_vals):
    """Bilinearly interpolate one point from the 2x2 block of z values at (ix, iy).

    Used for points along the edges of the grid. Kept as a standalone
    function so it can be jit compiled and vmapped.

    Parameters
    ----------
    ix, iy : int
        The indices of the lower corner of the cell containing the point.
    ax, ay : float
        The fractional position of the point within the cell along x and y
        (0 at node ix/iy, 1 at node ix+1/iy+1).
    z_vals : JAX Array
        A 2d matrix of z values.

    Returns
    -------
    float
        The bilinearly interpolated value at this point.
    """
    # First blend along y on each of the two bracketing x rows...
    lower_row = (1.0 - ay) * z_vals[ix][iy] + ay * z_vals[ix][iy + 1]
    upper_row = (1.0 - ay) * z_vals[ix + 1][iy] + ay * z_vals[ix + 1][iy + 1]

    # ...then blend those two results along x.
    return (1.0 - ax) * lower_row + ax * upper_row


def _compute_cubic(ix, iy, dx, dy, z_vals):
    """Bicubically interpolate one point from the 4x4 block of z values around (ix, iy).

    Used for points in the interior of the grid. Kept as a standalone
    function so it can be jit compiled and vmapped.

    Parameters
    ----------
    ix, iy : int
        The indices of the lower corner of the cell containing the point. The
        stencil reads rows ix-1 .. ix+2 and columns iy-1 .. iy+2.
    dx, dy : float
        The offset of node ix/iy from the query point, in units of the local
        cell width (i.e. (node - query) / step).
    z_vals : JAX Array
        A 2d matrix of z values.

    Returns
    -------
    float
        The bicubically interpolated value at this point.
    """
    # Kernel weights for the four stencil nodes at offsets -1, 0, +1, +2.
    x_weights = _kernel_value(jnp.asarray([dx - 1.0, dx, dx + 1.0, dx + 2.0]))
    y_weights = _kernel_value(jnp.asarray([dy - 1.0, dy, dy + 1.0, dy + 2.0]))

    # Collapse each of the four stencil rows along y first.
    row_sums = []
    for offset in range(-1, 3):
        row = z_vals[ix + offset]
        row_sums.append(
            y_weights[0] * row[iy - 1]
            + y_weights[1] * row[iy]
            + y_weights[2] * row[iy + 1]
            + y_weights[3] * row[iy + 2]
        )

    # Then combine the row results along x.
    return (
        x_weights[0] * row_sums[0]
        + x_weights[1] * row_sums[1]
        + x_weights[2] * row_sums[2]
        + x_weights[3] * row_sums[3]
    )


@jit
def expand_to_cross_products(x_vals, y_vals):
    """Flatten the full grid spanned by two 1-d arrays into paired arrays.

    Every x value is paired with every y value. Because the grid uses
    ``indexing="ij"`` and is raveled in row-major order, the y index varies
    fastest: element ``i * M + j`` of the outputs is ``(x_vals[i], y_vals[j])``.

    Parameters
    ----------
    x_vals : array-like
        A length N array of the x values.
    y_vals : array-like
        A length M array of the y values.

    Returns
    -------
    all_x : JAX array
        A length N * M array of the x values.
    all_y : JAX array
        A length N * M array of the y values.
    """
    x_grid, y_grid = jnp.meshgrid(x_vals, y_vals, indexing="ij")
    return x_grid.ravel(), y_grid.ravel()
class BicubicInterpolator:
    """An object that performs bicubic interpolation over a 2-d grid.

    Query points outside the grid (in either dimension) evaluate to 0.0. Points
    that fall in the first or last cell along either axis use bilinear
    interpolation, and all remaining points use bicubic (cubic convolution)
    interpolation, following sncosmo's BicubicInterpolator.

    Parameters
    ----------
    x_vals : array-like
        The values along the x-axis of the grid. Must be a 1-d, strictly
        increasing array with at least 3 values.
    y_vals : array-like
        The values along the y-axis of the grid. Must be a 1-d, strictly
        increasing array with at least 3 values.
    z_vals : array-like
        The values on the grid, with shape (len(x_vals), len(y_vals)).

    Attributes
    ----------
    x_vals : BicubicAxis
        The values along the x-axis of the grid stored with precomputed information.
    y_vals : BicubicAxis
        The values along the y-axis of the grid stored with precomputed information.
    z_vals : JAX array
        The values on the grid.

    Raises
    ------
    ValueError
        If z_vals is not 2-d or its shape does not match the two axes. The axes
        themselves are validated by BicubicAxis.
    """

    def __init__(self, x_vals, y_vals, z_vals):
        # Load and validate the x and y values.
        self.x_vals = BicubicAxis(jnp.asarray(x_vals))
        self.y_vals = BicubicAxis(jnp.asarray(y_vals))

        # Load and validate the z values.
        self.z_vals = jnp.asarray(z_vals)
        if len(self.z_vals.shape) != 2:
            raise ValueError(f"z values should be a 2-d array. Found shape={self.z_vals.shape}")
        if self.z_vals.shape[0] != len(self.x_vals) or self.z_vals.shape[1] != len(self.y_vals):
            raise ValueError(
                f"z values wrong shape. Expected shape=({len(self.x_vals)}, {len(self.y_vals)})."
                f" Found shape={self.z_vals.shape}."
            )

        # Vectorize the per-point kernels over the query points; the z grid is shared (None axis).
        self._compute_linear = vmap(jit(_compute_linear), in_axes=(0, 0, 0, 0, None))
        self._compute_cubic = vmap(jit(_compute_cubic), in_axes=(0, 0, 0, 0, None))

    @classmethod
    def from_grid_file(cls, filename, scale_factor=1.0):
        """Load the grid data from an ASCII file and create an interpolator.

        Parameters
        ----------
        filename : str
            The name of the grid file.
        scale_factor : float
            A multiplicative scale factor for the z values. Default: 1.0

        Returns
        -------
        BicubicInterpolator
            An instance of ``cls`` (so subclasses construct themselves).
        """
        x_vals, y_vals, z_vals = read_grid_data(filename)
        # Scale out of place so the loaded array is never mutated (and an integer
        # array does not fail an in-place float multiply).
        z_vals = z_vals * scale_factor
        return cls(x_vals, y_vals, z_vals)

    def __call__(self, x_q, y_q):
        """Evaluate the interpolation on the grid of points spanned by x_q and y_q.

        Parameters
        ----------
        x_q : array-like or scalar
            The N-length array of x values (a scalar is treated as length 1).
        y_q : array-like or scalar
            The M-length array of y values (a scalar is treated as length 1).

        Returns
        -------
        results : JAX array
            An N x M array where ``results[i, j]`` is the interpolated value
            at ``(x_q[i], y_q[j])``.
        """
        x_q = jnp.atleast_1d(jnp.asarray(x_q))
        y_q = jnp.atleast_1d(jnp.asarray(y_q))
        n_xq = x_q.shape[0]
        n_yq = y_q.shape[0]

        x_nodes = self.x_vals.values
        y_nodes = self.y_vals.values

        # Find the first index *before* each of the query values with
        # the n-1 index in each dimension mapped to n-2.
        ix = self.x_vals.find_indices(x_q)
        iy = self.y_vals.find_indices(y_q)
        ix_all, iy_all = expand_to_cross_products(ix, iy)

        # Fractional position of each query within its cell (0 at node i, 1 at node i + 1).
        x_frac = (x_q - x_nodes[ix]) / (x_nodes[ix + 1] - x_nodes[ix])
        y_frac = (y_q - y_nodes[iy]) / (y_nodes[iy + 1] - y_nodes[iy])

        # Bilinear interpolation (used along the edges).
        wx, wy = expand_to_cross_products(x_frac, y_frac)
        lin_res = self._compute_linear(ix_all, iy_all, wx, wy, self.z_vals)

        # Bicubic interpolation (used in the interior). The cubic kernel offsets use the
        # (node - query) / step convention, which is exactly the negated fraction.
        # These values are computed for every point, including edge points whose 4x4
        # stencil reaches past the grid; those results are discarded by the selection below.
        dx, dy = expand_to_cross_products(-x_frac, -y_frac)
        cubic_res = self._compute_cubic(ix_all, iy_all, dx, dy, self.z_vals)

        # Use 0.0 for anything out of bounds, linear interpolation for anything in the
        # first or last cell along either axis, and cubic interpolation elsewhere.
        x_out, y_out = expand_to_cross_products(
            self.x_vals.out_of_bounds(x_q),
            self.y_vals.out_of_bounds(y_q),
        )
        out_of_bounds = x_out | y_out

        x_edge, y_edge = expand_to_cross_products(
            (ix == 0) | (ix > self.x_vals.num_vals - 3),
            (iy == 0) | (iy > self.y_vals.num_vals - 3),
        )
        on_edge = x_edge | y_edge

        results = jnp.where(out_of_bounds, 0.0, jnp.where(on_edge, lin_res, cubic_res))
        return results.reshape(n_xq, n_yq)
class BicubicAxis:
    """A helper class that represents the node values along one axis of a bicubic
    interpolation grid, with restrictions on acceptable data to match the SALT2
    and SALT3 models.

    Restrictions include:
    - Data must be a 1-d array.
    - Data must contain at least 3 values.
    - Data must be strictly increasing (sorted, no duplicates).

    The steps between values do not have to be regular. When the step sizes agree
    to within 1e-6, a constant step size is used to locate query points directly;
    otherwise a binary search (``jnp.searchsorted``) over the values is used.

    Parameters
    ----------
    values : array-like
        The node values along this axis.

    Attributes
    ----------
    values : JAX Array
        The values of the range.
    min_val : float
        The starting value of the range.
    max_val : float
        The maximum value of the range.
    num_vals : int
        The number of values in the range.
    regular_steps : bool
        Indicates whether the axis uses regularly sized steps.
    step_size : float or None
        The (averaged) step size of the range if ``regular_steps`` is True,
        otherwise None.

    Raises
    ------
    ValueError
        If the values are not 1-d, contain fewer than 3 entries, or are not
        strictly increasing.
    """

    def __init__(self, values):
        # Load and validate the axis values.
        self.values = jnp.asarray(values)
        if len(self.values.shape) != 1:
            raise ValueError(
                f"The BicubicAxis values should be a 1-d array. Found shape={self.values.shape}."
            )
        if len(self.values) < 3:
            raise ValueError(
                f"Insufficient points for BicubicAxis. Required >= 3. Found {len(self.values)}."
            )
        if jnp.any(self.values[:-1] >= self.values[1:]):
            raise ValueError("The BicubicAxis values must be in sorted order.")

        self.num_vals = len(self.values)
        self.min_val = self.values[0]
        self.max_val = self.values[-1]

        # Check the step sizes are regular (allowing for imprecision when things are written as
        # floats). If it is regular we can use a faster algorithm to find the indices.
        # We pick and save a single core search function, so JAX does not encounter an "if" during
        # tracing (at the cost of another function call during find_indices).
        step_sizes = self.values[1:] - self.values[:-1]
        min_step = jnp.min(step_sizes)
        max_step = jnp.max(step_sizes)
        if max_step - min_step > 1e-6:
            self.regular_steps = False
            self.step_size = None
            self._initial_search = self._search_nonregular
        else:
            self.regular_steps = True
            self.step_size = 0.5 * min_step + 0.5 * max_step
            self._initial_search = self._search_regular

    def __len__(self):
        return self.num_vals

    def __str__(self):
        return f"BicubicAxis [{self.min_val},{self.max_val}]. step={self.step_size}. size={self.num_vals}"

    def _search_regular(self, query_pts):
        """Do an initial search for the query points indices given regular steps."""
        return jnp.floor((query_pts - self.min_val) / self.step_size).astype(int)

    def _search_nonregular(self, query_pts):
        """Do an initial search for the query points indices given non-regular steps."""
        return jnp.searchsorted(self.values, query_pts, side="right") - 1

    def find_indices(self, query_pts):
        """Finds the first index *before* each of the query values with the
        n - 1 index in each dimension mapped to n - 2::

            values[idx[i]] <= query_pts[i] < values[idx[i] + 1]

        for all i where idx[i] > 0 and idx[i] < n - 2. Indices are clamped to
        [0, n - 2], so ``idx + 1`` is always a valid index.

        Parameters
        ----------
        query_pts : array-like
            The values of the query points.

        Returns
        -------
        idx : JAX Array
            An array with one index per query point.
        """
        query_pts = jnp.asarray(query_pts)
        idx = self._initial_search(query_pts)
        # Points below the range map to 0; points at/after the last node map to n - 2.
        return jnp.clip(idx, 0, self.num_vals - 2)

    def out_of_bounds(self, values):
        """Compute a Boolean array of the values that are out of bounds.

        Parameters
        ----------
        values : array-like
            The N values to test

        Returns
        -------
        results : JAX array
            A length N array indicating whether each element is strictly below
            ``min_val`` or strictly above ``max_val``.
        """
        values = jnp.asarray(values)
        return (values < self.min_val) | (values > self.max_val)