Source code for lightcurvelynx.utils.extrapolate

"""Functions for extrapolating flux past the end of a model's range of valid
phases or wavelengths using flux = f(time, wavelengths).
"""

import abc

import numpy as np

from lightcurvelynx.astro_utils.mag_flux import flux2mag, mag2flux


[docs] class FluxExtrapolationModel(abc.ABC): """The base class for the flux extrapolation methods. Attributes ---------- nfit : int The number of points to be used for extrapolation. (Default is 1) """ @abc.abstractmethod def __init__(self): # By default we compute extrapolation using the last valid point. This can be changed # by setting nfit to a larger number (e.g. in LinearFit).
[docs] self.nfit = 1
@abc.abstractmethod def _extrapolate(self, last_values, last_fluxes, query_values): """Evaluate the extrapolation given the last valid points(s) and a list of new query points. Note ---- This function does not care which axis is being extrapolated. The returned values are always len(query_values) x len(last_fluxes) and may need to be transposed by the calling function. Parameters ---------- last_values : float or np.ndarray The last valid value or a length L array of the last valid values along the extrapolation axis at which the fluxes were predicted (e.g., wavelength in AA or time in days). last_fluxes : numpy.ndarray A length N x L array of the flux values at the last valid time or wavelength (in nJy). query_values : numpy.ndarray A length M array of values along the extrapolation axis (times in days or wavelengths in AA) at which to extrapolate. Returns ------- flux : numpy.ndarray A N x M matrix of extrapolated values. Where M is the number of query points and N is the number of flux values at the last valid point. """ raise NotImplementedError("Subclasses must implement this method.") # pragma: no cover
[docs] def extrapolate_time(self, last_times, last_fluxes, query_times): """Extrapolate along the time axis. Parameters ---------- last_times : float or np.ndarray A length T1 array of the last valid times (in days) at which the fluxes were predicted. last_fluxes : numpy.ndarray A length T1 x W array of the last valid flux values at each wavelength at the last valid times (in nJy). query_times : numpy.ndarray A length T2 array of the query times (in days) at which to extrapolate. Returns ------- flux : numpy.ndarray A T2 x W matrix of extrapolated values. """ last_fluxes = last_fluxes.T return self._extrapolate(last_times, last_fluxes, query_times).T
[docs] def extrapolate_wavelength(self, last_waves, last_fluxes, query_waves): """Extrapolate along the wavelength axis. Parameters ---------- last_waves : float or np.ndarray A length W1 array of the last valid wavelengths (in AA) at which the fluxes were predicted. last_fluxes : numpy.ndarray A length T x W1 array of the last valid flux values at each time at the last valid wavelength (in nJy). query_waves : numpy.ndarray A length W2 array of the query wavelengths (in AA) at which to extrapolate. Returns ------- flux : numpy.ndarray A T x W2 matrix of extrapolated values. """ # We transpose the result to turn the W2 x T matrix into a T x W matrix. return self._extrapolate(last_waves, last_fluxes, query_waves)
[docs] class ZeroPadding(FluxExtrapolationModel): """Extrapolate by zero padding the results.""" def __init__(self): super().__init__() def _extrapolate(self, last_values, last_fluxes, query_values): """Evaluate the extrapolation given the last valid points(s) and a list of new query points. Note ---- This function does not care which axis is being extrapolated. The returned values are always len(query_values) x len(last_fluxes) and may need to be transposed by the calling function. Parameters ---------- last_values : float or np.ndarray The last valid value along the extrapolation axis at which the flux was predicted (e.g., wavelength in AA or time in days). last_fluxes : numpy.ndarray A length N array of the flux values at the last valid time or wavelength (in nJy). query_values : numpy.ndarray A length M array of values along the extrapolation axis (times in days or wavelengths in AA) at which to extrapolate. Returns ------- flux : numpy.ndarray A N x M matrix of extrapolated values. Where M is the number of query points and N is the number of flux values at the last valid point. """ N_len = len(last_fluxes) M_len = len(query_values) return np.zeros((N_len, M_len))
[docs] class ConstantPadding(FluxExtrapolationModel): """Extrapolate using a constant value in nJy. Attributes ---------- value : float The value (in nJy) to use for the extrapolation. """ def __init__(self, value=0.0): super().__init__() if value < 0: raise ValueError("Extrapolation value must be positive.")
[docs] self.value = value
def _extrapolate(self, last_values, last_fluxes, query_values): """Evaluate the extrapolation given the last valid points(s) and a list of new query points. Note ---- This function does not care which axis is being extrapolated. The returned values are always len(query_values) x len(last_fluxes) and may need to be transposed by the calling function. Parameters ---------- last_values : float or np.ndarray The last valid value along the extrapolation axis at which the flux was predicted (e.g., wavelength in AA or time in days). last_fluxes : numpy.ndarray A length N array of the flux values at the last valid time or wavelength (in nJy). query_values : numpy.ndarray A length M array of values along the extrapolation axis (times in days or wavelengths in AA) at which to extrapolate. Returns ------- flux : numpy.ndarray A N x M matrix of extrapolated values. Where M is the number of query points and N is the number of flux values at the last valid point. """ N_len = len(last_fluxes) M_len = len(query_values) return np.full((N_len, M_len), self.value)
[docs] class LastValue(FluxExtrapolationModel): """Extrapolate using the last valid value along the desired axis.""" def __init__(self): super().__init__() def _extrapolate(self, last_values, last_fluxes, query_values): """Evaluate the extrapolation given the last valid points(s) and a list of new query points. Note ---- This function does not care which axis is being extrapolated. The returned values are always len(query_values) x len(last_fluxes) and may need to be transposed by the calling function. Parameters ---------- last_values : float or np.ndarray The last valid value along the extrapolation axis at which the flux was predicted (e.g., wavelength in AA or time in days). last_fluxes : numpy.ndarray A length N array of the flux values at the last valid time or wavelength (in nJy). query_values : numpy.ndarray A length M array of values along the extrapolation axis (times in days or wavelengths in AA) at which to extrapolate. Returns ------- flux : numpy.ndarray A N x M matrix of extrapolated values. Where M is the number of query points and N is the number of flux values at the last valid point. """ last_fluxes = np.asarray(last_fluxes).reshape(-1) return np.tile(last_fluxes[:, np.newaxis], (1, len(query_values)))
[docs] class LinearDecay(FluxExtrapolationModel): """Linear decay of the flux using the last valid point(s) down to zero. Attributes ---------- decay_width : float or np.ndarray The width of the decay region in Angstroms. The flux is linearly decreased to zero over this range. """ def __init__(self, decay_width=100.0): super().__init__() if decay_width <= 0: raise ValueError("decay_width must be positive.")
[docs] self.decay_width = decay_width
def _extrapolate(self, last_values, last_fluxes, query_values): """Evaluate the extrapolation given the last valid points(s) and a list of new query points. Note ---- This function does not care which axis is being extrapolated. The returned values are always len(query_values) x len(last_fluxes) and may need to be transposed by the calling function. Parameters ---------- last_values : float or np.ndarray The last valid value along the extrapolation axis at which the flux was predicted (e.g., wavelength in AA or time in days). last_fluxes : numpy.ndarray A length N array of the flux values at the last valid time or wavelength (in nJy). query_values : numpy.ndarray A length M array of values along the extrapolation axis (times in days or wavelengths in AA) at which to extrapolate. Returns ------- flux : numpy.ndarray A N x M matrix of extrapolated values. Where M is the number of query points and N is the number of flux values at the last valid point. """ last_fluxes = np.asarray(last_fluxes).reshape(-1) query_values = np.asarray(query_values) dist = np.abs(query_values - last_values) multiplier = np.clip(1.0 - (dist / self.decay_width), 0.0, 1.0) flux = last_fluxes[:, np.newaxis] * multiplier[np.newaxis, :] return flux
[docs] class LinearDecayOnMag(FluxExtrapolationModel): """Linear decay of the converted magnitude using the last valid point(s) with a fixed decay rate down to a specific magnitude threshold. This is generally used for extrapolating in the time axis. Attributes ---------- decay_rate : float or np.ndarray The rate of the decay region in days. The magnitude is linearly decreased to the mag_thres over this range. mag_thres : float or np.ndarray The mag threshold for the linear decay extraplation. Fluxes are capped at this value for time/wavelength beyond this value. """ def __init__(self, decay_rate=0.02, mag_thres=40.0): super().__init__() if decay_rate <= 0: raise ValueError("decay_rate must be positive.")
[docs] self.decay_rate = decay_rate
[docs] self.mag_thres = mag_thres
def _extrapolate(self, last_values, last_fluxes, query_values): """Evaluate the extrapolation given the last valid points(s) and a list of new query points. Note ---- This function does not care which axis is being extrapolated. The returned values are always len(query_values) x len(last_fluxes) and may need to be transposed by the calling function. Parameters ---------- last_values : float or np.ndarray The last valid value along the extrapolation axis at which the flux was predicted (e.g., wavelength in AA or time in days). last_fluxes : numpy.ndarray A length N array of the flux values at the last valid time or wavelength (in nJy). query_values : numpy.ndarray A length M array of values along the extrapolation axis (times in days or wavelengths in AA) at which to extrapolate. Returns ------- flux : numpy.ndarray A N x M matrix of extrapolated values. Where M is the number of query points and N is the number of flux values at the last valid point. """ last_fluxes = np.asarray(last_fluxes).reshape(-1) last_fluxes = np.clip(last_fluxes, mag2flux(self.mag_thres), None) last_mags = flux2mag(last_fluxes) query_values = np.asarray(query_values) dist = np.abs(query_values - last_values) mag = np.clip(last_mags[:, np.newaxis] + dist[np.newaxis, :] * self.decay_rate, None, self.mag_thres) return mag2flux(mag)
[docs] class ExponentialDecay(FluxExtrapolationModel): """Exponential decay of the flux using the last valid point(s) down to zero. f(t, λ) = f(t, λ_last) * exp(- rate * \\|λ - λ_last\\|) Attributes ---------- rate : float The decay rate in the exponential function. """ def __init__(self, rate): super().__init__() if rate < 0: raise ValueError("Decay rate must be zero or positive.")
[docs] self.rate = rate
def _extrapolate(self, last_values, last_fluxes, query_values): """Evaluate the extrapolation given the last valid points(s) and a list of new query points. Note ---- This function does not care which axis is being extrapolated. The returned values are always len(query_values) x len(last_fluxes) and may need to be transposed by the calling function. Parameters ---------- last_values : float or np.ndarray The last valid value along the extrapolation axis at which the flux was predicted (e.g., wavelength in AA or time in days). last_fluxes : numpy.ndarray A length N array of the flux values at the last valid time or wavelength (in nJy). query_values : numpy.ndarray A length M array of values along the extrapolation axis (times in days or wavelengths in AA) at which to extrapolate. Returns ------- flux : numpy.ndarray A N x M matrix of extrapolated values. Where M is the number of query points and N is the number of flux values at the last valid point. """ last_fluxes = np.asarray(last_fluxes).reshape(-1) query_values = np.asarray(query_values) dist = np.abs(query_values - last_values) multiplier = np.exp(-self.rate * dist) flux = last_fluxes[:, np.newaxis] * multiplier[np.newaxis, :] return flux
def _bin_rows_median(last_fluxes, nbin, *, nan_safe=True): """Bin the input fluxes on the first axis given number of bins and return the median values of each bin. This is used for binning the last fluxes to avoid extrapolating to extreme values. Parameters ---------- last_fluxes : np.ndarray A N x T array of the input fluxes to be binned. nbin : int Number of bins along the first axis. nan_safe : bool, optional If True, use np.nanmedian (ignore NaNs). If False, use np.median. Returns ------- binned_fluxes : np.ndarray A nbin x T array of the binned fluxes. """ last_fluxes = np.asarray(last_fluxes) N, T = last_fluxes.shape if nbin > N: raise ValueError("nbin must be smaller or equal to N") # Bin edges that evenly partition rows edges = np.linspace(0, N, nbin + 1, dtype=int) binned_fluxes = np.empty((nbin, T), dtype=float) for b in range(nbin): lo, hi = edges[b], edges[b + 1] chunk = last_fluxes[lo:hi] binned_fluxes[b] = np.nanmedian(chunk, axis=0) if nan_safe else np.median(chunk, axis=0) return binned_fluxes
[docs] class LinearFit(FluxExtrapolationModel): """Linear extrapolation based on a linear fit to the last few points. Parameters ---------- nfit : int The number of points to be used for extrapolation. (Default is 5) nbin : int The number of bins to be used to bin the last fluxes. This can be used to avoid extrapolating to extreme values when models are not well-behaved in smaller bins. """ def __init__(self, nfit=5, nbin=None): super().__init__()
[docs] self.nfit = nfit
[docs] self.nbin = nbin
def _extrapolate(self, last_values, last_fluxes, query_values): """Evaluate the extrapolation given the last valid points(s) and a list of new query points. Parameters ---------- last_values : np.ndarray A T elements array of the last values along the extrapolation axis at which the flux was predicted (e.g., wavelength in AA or time in days). last_fluxes : ndarray A length N x T matrix of the flux values at the last valid time or wavelength (in nJy). query_values : ndarray A length M array of values along the extrapolation axis (times in days or wavelengths in AA) at which to extrapolate. Returns ------- flux : ndarray A N x M matrix of extrapolated values. Where M is the number of query points and N is the number of flux values at the last valid point. """ if len(last_values) <= 1: raise ValueError("Need at least two points to extrapolate using this method.") N = last_fluxes.shape[0] if self.nbin is None: binned_fluxes = last_fluxes else: # guard: can't have more bins than rows nbin = int(min(self.nbin, N)) binned_fluxes = _bin_rows_median(last_fluxes, nbin=nbin, nan_safe=True) A = np.column_stack([last_values, np.ones_like(last_values)]) B = np.array(binned_fluxes, dtype=float, copy=True).T coeffs = np.linalg.lstsq(A, B, rcond=None)[0] slope, intercept = coeffs # (nbin, M) flux_binned = slope[:, None] * query_values[None, :] + intercept[:, None] flux_binned = np.clip(flux_binned, 0.0, None) # Expand back to (N, M): row i gets its bin's curve if self.nbin is None: flux = flux_binned else: row_to_bin = (np.arange(N) * nbin) // N flux = flux_binned[row_to_bin] return flux
[docs] class LinearFitOnMag(FluxExtrapolationModel): """Linear extrapolation based on a linear fit to the coverted magnitude of the last few points. Parameters ---------- nfit : int The number of points to be used for extrapolation. (Default is 5) nbin : int The number of bins to be used to bin the last fluxes. This can be used to avoid extrapolating to extreme values when models are not well-behaved in smaller bins. """ def __init__(self, nfit=5, nbin=None): super().__init__()
[docs] self.nfit = nfit
[docs] self.nbin = nbin
def _extrapolate(self, last_values, last_fluxes, query_values): """Evaluate the extrapolation given the last valid points(s) and a list of new query points. Parameters ---------- last_values : np.ndarray A T elements array of the last values along the extrapolation axis at which the flux was predicted (e.g., wavelength in AA or time in days). last_fluxes : ndarray A length N x T matrix of the flux values at the last valid time or wavelength (in nJy). query_values : ndarray A length M array of values along the extrapolation axis (times in days or wavelengths in AA) at which to extrapolate. Returns ------- flux : ndarray A N x M matrix of extrapolated values. Where M is the number of query points and N is the number of flux values at the last valid point. """ if len(last_values) <= 1: raise ValueError("Need at least two points to extrapolate using this method.") N = last_fluxes.shape[0] last_fluxes = np.clip(last_fluxes, 1.0e-40, None) last_fluxes = flux2mag(last_fluxes) if self.nbin is None: binned_fluxes = last_fluxes else: # guard: can't have more bins than rows (otherwise some bins empty -> median NaN) nbin = int(min(self.nbin, N)) binned_fluxes = _bin_rows_median(last_fluxes, nbin=nbin, nan_safe=True) A = np.column_stack([last_values, np.ones_like(last_values)]) B = np.array(binned_fluxes, dtype=float, copy=True).T coeffs = np.linalg.lstsq(A, B, rcond=None)[0] slope, intercept = coeffs # (nbin, M) flux_binned = slope[:, None] * query_values[None, :] + intercept[:, None] flux_binned = np.clip(flux_binned, 0.0, None) if self.nbin is None: flux = flux_binned else: # Expand back to (N, M): row i gets its bin's curve row_to_bin = (np.arange(N) * nbin) // N flux = flux_binned[row_to_bin] return mag2flux(flux)