"""Functions for extrapolating flux past the end of a model's range of valid
phases or wavelengths using flux = f(time, wavelengths).
"""
import abc
import numpy as np
from lightcurvelynx.astro_utils.mag_flux import flux2mag, mag2flux
class ZeroPadding(FluxExtrapolationModel):
    """Extrapolate by zero padding the results."""

    def __init__(self):
        super().__init__()

    def _extrapolate(self, last_values, last_fluxes, query_values):
        """Evaluate the extrapolation given the last valid point(s) and a list of new
        query points.

        Note
        ----
        This function does not care which axis is being extrapolated. The returned values are
        always len(last_fluxes) x len(query_values) (N x M) and may need to be transposed by
        the calling function.

        Parameters
        ----------
        last_values : float or np.ndarray
            The last valid value along the extrapolation axis at which the flux was predicted
            (e.g., wavelength in AA or time in days). Not used by this model.
        last_fluxes : float or numpy.ndarray
            A length N array of the flux values at the last valid time or wavelength (in nJy).
            A scalar is treated as a length 1 array.
        query_values : float or numpy.ndarray
            A length M array of values along the extrapolation axis (times in days or wavelengths
            in AA) at which to extrapolate. A scalar is treated as a length 1 array.

        Returns
        -------
        flux : numpy.ndarray
            A N x M matrix of zeros (float). M is the number of query points and N is the
            number of flux values at the last valid point.
        """
        # np.atleast_1d(...).shape[0] equals len() for any array input, but also accepts
        # scalars, which len() would reject with a TypeError.
        n_rows = np.atleast_1d(last_fluxes).shape[0]
        n_query = np.atleast_1d(query_values).shape[0]
        return np.zeros((n_rows, n_query))
class ConstantPadding(FluxExtrapolationModel):
    """Extrapolate using a constant value in nJy.

    Attributes
    ----------
    value : float
        The value (in nJy) to use for the extrapolation. Must be non-negative.
    """

    def __init__(self, value=0.0):
        super().__init__()
        if value < 0:
            raise ValueError("Extrapolation value must be positive.")
        # Stored so that _extrapolate can read it; without this the model cannot be evaluated.
        self.value = value

    def _extrapolate(self, last_values, last_fluxes, query_values):
        """Evaluate the extrapolation given the last valid point(s) and a list of new
        query points.

        Note
        ----
        This function does not care which axis is being extrapolated. The returned values are
        always len(last_fluxes) x len(query_values) (N x M) and may need to be transposed by
        the calling function.

        Parameters
        ----------
        last_values : float or np.ndarray
            The last valid value along the extrapolation axis at which the flux was predicted
            (e.g., wavelength in AA or time in days). Not used by this model.
        last_fluxes : float or numpy.ndarray
            A length N array of the flux values at the last valid time or wavelength (in nJy).
            Only its length is used. A scalar is treated as a length 1 array.
        query_values : float or numpy.ndarray
            A length M array of values along the extrapolation axis (times in days or wavelengths
            in AA) at which to extrapolate. A scalar is treated as a length 1 array.

        Returns
        -------
        flux : numpy.ndarray
            A N x M float matrix filled with ``self.value``. M is the number of query points and
            N is the number of flux values at the last valid point.
        """
        # Same length semantics as len() for arrays, but scalars are also accepted.
        n_rows = np.atleast_1d(last_fluxes).shape[0]
        n_query = np.atleast_1d(query_values).shape[0]
        # Force a float result so an integer padding value does not yield an integer flux array.
        return np.full((n_rows, n_query), self.value, dtype=float)
class LastValue(FluxExtrapolationModel):
    """Extrapolate by holding the last valid flux constant along the requested axis."""

    def __init__(self):
        super().__init__()

    def _extrapolate(self, last_values, last_fluxes, query_values):
        """Repeat each last valid flux at every query point.

        Note
        ----
        The extrapolation axis (time or wavelength) is irrelevant here. The result is shaped
        len(last_fluxes) x len(query_values) (N x M); callers may need to transpose it.

        Parameters
        ----------
        last_values : float or np.ndarray
            Position(s) on the extrapolation axis of the last valid prediction
            (e.g., wavelength in AA or time in days). Not used by this model.
        last_fluxes : numpy.ndarray
            The N flux values (in nJy) at the last valid time or wavelength. The input is
            flattened to one dimension.
        query_values : numpy.ndarray
            The M positions on the extrapolation axis (times in days or wavelengths in AA)
            to extrapolate to.

        Returns
        -------
        flux : numpy.ndarray
            An N x M matrix whose row i is last_fluxes[i] repeated M times.
        """
        flat_fluxes = np.asarray(last_fluxes).reshape(-1)
        n_query = len(query_values)
        column = flat_fluxes[:, np.newaxis]
        return np.repeat(column, n_query, axis=1)
class LinearDecay(FluxExtrapolationModel):
    """Linear decay of the flux using the last valid point(s) down to zero.

    Attributes
    ----------
    decay_width : float or np.ndarray
        The width of the decay region, in the units of the extrapolation axis (e.g. Angstroms
        for wavelength). The flux falls linearly from its last valid value to zero over this
        distance and stays at zero beyond it.
    """

    def __init__(self, decay_width=100.0):
        super().__init__()
        if decay_width <= 0:
            raise ValueError("decay_width must be positive.")
        self.decay_width = decay_width

    def _extrapolate(self, last_values, last_fluxes, query_values):
        """Scale the last valid fluxes by a linear ramp that reaches zero at decay_width.

        Note
        ----
        The extrapolation axis (time or wavelength) is irrelevant here. The result is shaped
        len(last_fluxes) x len(query_values) (N x M); callers may need to transpose it.

        Parameters
        ----------
        last_values : float or np.ndarray
            Position(s) on the extrapolation axis of the last valid prediction
            (e.g., wavelength in AA or time in days).
        last_fluxes : numpy.ndarray
            The N flux values (in nJy) at the last valid time or wavelength. The input is
            flattened to one dimension.
        query_values : numpy.ndarray
            The M positions on the extrapolation axis (times in days or wavelengths in AA)
            to extrapolate to.

        Returns
        -------
        flux : numpy.ndarray
            An N x M matrix: last_fluxes[i] * clip(1 - |q_j - last| / decay_width, 0, 1).
        """
        flat_fluxes = np.asarray(last_fluxes).reshape(-1)
        queries = np.asarray(query_values)
        distance = np.abs(queries - last_values)
        # 1 at the boundary, decreasing linearly, and held at 0 once distance >= decay_width.
        ramp = np.clip(1.0 - distance / self.decay_width, 0.0, 1.0)
        return flat_fluxes[:, np.newaxis] * ramp[np.newaxis, :]
class LinearDecayOnMag(FluxExtrapolationModel):
    """Linear decay of the converted magnitude using the last valid point(s) with a fixed decay
    rate down to a specific magnitude threshold. This is generally used for extrapolating in
    the time axis.

    Attributes
    ----------
    decay_rate : float or np.ndarray
        The number of magnitudes added per unit distance along the extrapolation axis
        (e.g. magnitudes per day when extrapolating in time).
    mag_thres : float or np.ndarray
        The magnitude threshold for the linear decay extrapolation. Extrapolated magnitudes are
        capped at this value, and last fluxes are floored at its flux equivalent before
        conversion.
    """

    def __init__(self, decay_rate=0.02, mag_thres=40.0):
        super().__init__()
        if decay_rate <= 0:
            raise ValueError("decay_rate must be positive.")
        self.decay_rate = decay_rate
        self.mag_thres = mag_thres

    def _extrapolate(self, last_values, last_fluxes, query_values):
        """Decay the last valid fluxes linearly in magnitude space, capped at mag_thres.

        Note
        ----
        The extrapolation axis (time or wavelength) is irrelevant here. The result is shaped
        len(last_fluxes) x len(query_values) (N x M); callers may need to transpose it.

        Parameters
        ----------
        last_values : float or np.ndarray
            Position(s) on the extrapolation axis of the last valid prediction
            (e.g., wavelength in AA or time in days).
        last_fluxes : numpy.ndarray
            The N flux values (in nJy) at the last valid time or wavelength. The input is
            flattened to one dimension.
        query_values : numpy.ndarray
            The M positions on the extrapolation axis (times in days or wavelengths in AA)
            to extrapolate to.

        Returns
        -------
        flux : numpy.ndarray
            An N x M matrix of fluxes (nJy) converted back from the decayed magnitudes.
        """
        flat_fluxes = np.asarray(last_fluxes).reshape(-1)
        # Floor the fluxes at the flux equivalent of mag_thres before converting to magnitudes
        # (presumably also so flux2mag never receives zero or negative fluxes).
        floored = np.clip(flat_fluxes, mag2flux(self.mag_thres), None)
        start_mags = flux2mag(floored)

        distance = np.abs(np.asarray(query_values) - last_values)
        # Magnitude grows linearly with distance from the boundary, then saturates at mag_thres.
        decayed = start_mags[:, np.newaxis] + distance[np.newaxis, :] * self.decay_rate
        capped = np.clip(decayed, None, self.mag_thres)
        return mag2flux(capped)
class ExponentialDecay(FluxExtrapolationModel):
    """Exponential decay of the flux using the last valid point(s) down to zero.
    f(t, λ) = f(t, λ_last) * exp(- rate * \\|λ - λ_last\\|)

    Attributes
    ----------
    rate : float
        The decay rate in the exponential function (in inverse units of the extrapolation
        axis). Must be zero or positive; zero keeps the last value constant.
    """

    def __init__(self, rate):
        super().__init__()
        if rate < 0:
            raise ValueError("Decay rate must be zero or positive.")
        # Stored so that _extrapolate can read it; without this the model cannot be evaluated.
        self.rate = rate

    def _extrapolate(self, last_values, last_fluxes, query_values):
        """Evaluate the extrapolation given the last valid point(s) and a list of new
        query points.

        Note
        ----
        This function does not care which axis is being extrapolated. The returned values are
        always len(last_fluxes) x len(query_values) (N x M) and may need to be transposed by
        the calling function.

        Parameters
        ----------
        last_values : float or np.ndarray
            The last valid value along the extrapolation axis at which the flux was predicted
            (e.g., wavelength in AA or time in days).
        last_fluxes : numpy.ndarray
            A length N array of the flux values at the last valid time or wavelength (in nJy).
            The input is flattened to one dimension.
        query_values : numpy.ndarray
            A length M array of values along the extrapolation axis (times in days or wavelengths
            in AA) at which to extrapolate.

        Returns
        -------
        flux : numpy.ndarray
            A N x M matrix of extrapolated values. Where M is the number of query points and
            N is the number of flux values at the last valid point.
        """
        last_fluxes = np.asarray(last_fluxes).reshape(-1)
        query_values = np.asarray(query_values)
        dist = np.abs(query_values - last_values)
        multiplier = np.exp(-self.rate * dist)
        return last_fluxes[:, np.newaxis] * multiplier[np.newaxis, :]
def _bin_rows_median(last_fluxes, nbin, *, nan_safe=True):
    """Bin the input fluxes on the first axis given number of bins and return the median values
    of each bin. This is used for binning the last fluxes to avoid extrapolating to extreme
    values.

    Rows are split into ``nbin`` contiguous groups: bin ``b`` holds rows
    ``ceil(b * N / nbin)`` through ``ceil((b + 1) * N / nbin) - 1``. Equivalently, row ``i``
    belongs to bin ``(i * nbin) // N``, which is the rule LinearFit and LinearFitOnMag use to
    map binned results back onto rows.

    Parameters
    ----------
    last_fluxes : np.ndarray
        A N x T array of the input fluxes to be binned.
    nbin : int
        Number of bins along the first axis. Must satisfy 0 <= nbin <= N.
    nan_safe : bool, optional
        If True, use np.nanmedian (ignore NaNs).
        If False, use np.median.

    Returns
    -------
    binned_fluxes : np.ndarray
        A nbin x T array of the binned fluxes.

    Raises
    ------
    ValueError
        If nbin is larger than N or negative.
    """
    last_fluxes = np.asarray(last_fluxes)
    N, T = last_fluxes.shape
    if nbin > N:
        raise ValueError("nbin must be smaller or equal to N")
    if nbin < 0:
        raise ValueError("nbin must be non-negative")
    if nbin == 0:
        # Nothing to bin (e.g. N == 0); avoid dividing by zero when computing edges.
        return np.empty((0, T), dtype=float)

    # Exact integer ceil-division edges. Unlike np.linspace(..., dtype=int) (floor of a float),
    # these cannot drift through float rounding, and they place row i in bin (i * nbin) // N.
    # Since nbin <= N, every bin is non-empty.
    edges = (np.arange(nbin + 1) * N + nbin - 1) // nbin
    reducer = np.nanmedian if nan_safe else np.median

    binned_fluxes = np.empty((nbin, T), dtype=float)
    for b in range(nbin):
        binned_fluxes[b] = reducer(last_fluxes[edges[b]:edges[b + 1]], axis=0)
    return binned_fluxes
class LinearFit(FluxExtrapolationModel):
    """Linear extrapolation based on a linear fit to the last few points.

    Parameters
    ----------
    nfit : int
        The number of points to be used for extrapolation. (Default is 5)
    nbin : int or None
        The number of bins to be used to bin the last fluxes. This can be used to avoid extrapolating
        to extreme values when models are not well-behaved in smaller bins. If None (default),
        every row is fit independently.
    """

    def __init__(self, nfit=5, nbin=None):
        super().__init__()
        if nbin is not None and nbin < 1:
            raise ValueError("nbin must be a positive integer or None.")
        # nbin is read by _extrapolate. nfit is not read in this class; presumably the base
        # class uses it to decide how many trailing points to pass in -- TODO confirm.
        self.nfit = nfit
        self.nbin = nbin

    def _extrapolate(self, last_values, last_fluxes, query_values):
        """Evaluate the extrapolation given the last valid point(s) and a list of new
        query points.

        Each row (or each bin of rows, when ``nbin`` is set) is fit with a least-squares line
        over ``last_values``; the line is evaluated at ``query_values`` and negative results
        are clipped to zero.

        Parameters
        ----------
        last_values : np.ndarray
            A T elements array of the last values along the extrapolation axis at which the flux was predicted
            (e.g., wavelength in AA or time in days). Must contain at least two values.
        last_fluxes : ndarray
            A N x T matrix of the flux values at the last valid times or wavelengths (in nJy).
        query_values : ndarray
            A length M array of values along the extrapolation axis (times in days or wavelengths
            in AA) at which to extrapolate.

        Returns
        -------
        flux : ndarray
            A N x M matrix of extrapolated values. Where M is the number of query points and
            N is the number of flux values at the last valid point. When binning, every row
            receives the curve fit to its bin.

        Raises
        ------
        ValueError
            If fewer than two last values are given.
        """
        last_values = np.asarray(last_values, dtype=float)
        if last_values.size <= 1:
            raise ValueError("Need at least two points to extrapolate using this method.")
        last_fluxes = np.asarray(last_fluxes, dtype=float)
        query_values = np.asarray(query_values, dtype=float)
        n_rows = last_fluxes.shape[0]

        if self.nbin is None:
            binned_fluxes = last_fluxes
        else:
            # guard: can't have more bins than rows
            nbin = int(min(self.nbin, n_rows))
            binned_fluxes = _bin_rows_median(last_fluxes, nbin=nbin, nan_safe=True)

        # Solve value = slope * x + intercept for all rows at once:
        # design is (T, 2), the right-hand side is (T, rows), coeffs is (2, rows).
        design = np.column_stack([last_values, np.ones_like(last_values)])
        coeffs = np.linalg.lstsq(design, binned_fluxes.T, rcond=None)[0]
        slope, intercept = coeffs

        # (rows or nbin, M); fluxes are not allowed to go negative.
        flux_binned = slope[:, None] * query_values[None, :] + intercept[:, None]
        flux_binned = np.clip(flux_binned, 0.0, None)

        if self.nbin is None:
            return flux_binned
        # Expand back to (N, M): row i gets the curve of bin (i * nbin) // N. This must agree
        # with the row partition used by _bin_rows_median.
        row_to_bin = (np.arange(n_rows) * nbin) // n_rows
        return flux_binned[row_to_bin]
class LinearFitOnMag(FluxExtrapolationModel):
    """Linear extrapolation based on a linear fit to the converted magnitude of the last few points.

    Parameters
    ----------
    nfit : int
        The number of points to be used for extrapolation. (Default is 5)
    nbin : int or None
        The number of bins to be used to bin the last fluxes. This can be used to avoid extrapolating
        to extreme values when models are not well-behaved in smaller bins. If None (default),
        every row is fit independently.
    """

    def __init__(self, nfit=5, nbin=None):
        super().__init__()
        if nbin is not None and nbin < 1:
            raise ValueError("nbin must be a positive integer or None.")
        # nbin is read by _extrapolate. nfit is not read in this class; presumably the base
        # class uses it to decide how many trailing points to pass in -- TODO confirm.
        self.nfit = nfit
        self.nbin = nbin

    def _extrapolate(self, last_values, last_fluxes, query_values):
        """Evaluate the extrapolation given the last valid point(s) and a list of new
        query points.

        The fluxes are converted to magnitudes, each row (or each bin of rows, when ``nbin``
        is set) is fit with a least-squares line over ``last_values``, and the line is evaluated
        at ``query_values`` and converted back to flux.

        Parameters
        ----------
        last_values : np.ndarray
            A T elements array of the last values along the extrapolation axis at which the flux was predicted
            (e.g., wavelength in AA or time in days). Must contain at least two values.
        last_fluxes : ndarray
            A N x T matrix of the flux values at the last valid times or wavelengths (in nJy).
        query_values : ndarray
            A length M array of values along the extrapolation axis (times in days or wavelengths
            in AA) at which to extrapolate.

        Returns
        -------
        flux : ndarray
            A N x M matrix of extrapolated values. Where M is the number of query points and
            N is the number of flux values at the last valid point. When binning, every row
            receives the curve fit to its bin.

        Raises
        ------
        ValueError
            If fewer than two last values are given.
        """
        last_values = np.asarray(last_values, dtype=float)
        if last_values.size <= 1:
            raise ValueError("Need at least two points to extrapolate using this method.")
        last_fluxes = np.asarray(last_fluxes, dtype=float)
        query_values = np.asarray(query_values, dtype=float)
        n_rows = last_fluxes.shape[0]

        # Floor at a tiny positive flux before converting, presumably so flux2mag never sees
        # zero or negative fluxes.
        last_mags = flux2mag(np.clip(last_fluxes, 1.0e-40, None))

        if self.nbin is None:
            binned_mags = last_mags
        else:
            # guard: can't have more bins than rows (otherwise some bins empty -> median NaN)
            nbin = int(min(self.nbin, n_rows))
            binned_mags = _bin_rows_median(last_mags, nbin=nbin, nan_safe=True)

        # Solve mag = slope * x + intercept for all rows at once:
        # design is (T, 2), the right-hand side is (T, rows), coeffs is (2, rows).
        design = np.column_stack([last_values, np.ones_like(last_values)])
        coeffs = np.linalg.lstsq(design, np.asarray(binned_mags, dtype=float).T, rcond=None)[0]
        slope, intercept = coeffs

        # (rows or nbin, M), in magnitudes.
        mag_binned = slope[:, None] * query_values[None, :] + intercept[:, None]
        # NOTE(review): this clips magnitudes (not fluxes) from below at 0, i.e. it caps the
        # extrapolated brightness at magnitude 0. Kept for backward compatibility; confirm
        # that this bound is intended.
        mag_binned = np.clip(mag_binned, 0.0, None)

        if self.nbin is None:
            mags = mag_binned
        else:
            # Expand back to (N, M): row i gets the curve of bin (i * nbin) // N. This must
            # agree with the row partition used by _bin_rows_median.
            row_to_bin = (np.arange(n_rows) * nbin) // n_rows
            mags = mag_binned[row_to_bin]
        return mag2flux(mags)