from pathlib import Path
import h5py
import numpy as np
from astropy import units as u
from citation_compass import CiteClass
from lightcurvelynx import _LIGHTCURVELYNX_DOWNLOAD_DATA_DIR
from lightcurvelynx.astro_utils.unit_utils import flam_to_fnu
from lightcurvelynx.effects.extinction import ExtinctionEffect
from lightcurvelynx.models.physical_model import SEDModel
from lightcurvelynx.utils.data_download import download_data_file_if_needed
# Local cache directory (inside the package download directory) where the
# BayeSN data files fetched by ``BayesnModel`` are stored.
_LIGHTCURVELYNX_BAYESN_CACHE_PATH = _LIGHTCURVELYNX_DOWNLOAD_DATA_DIR / "bayesn-data"
[docs]
class BayesnModel(SEDModel, CiteClass):
    """A bayesian model for supernova type Ia

    The model is defined in (Mandel et al 2022) as::

        flux(time, wave) = H_grid * 10 ** (-0.4 * W_grid) * 10 ** (-0.4 * (distmod _ m_abs))

    NOTE(review): the ``_`` between ``distmod`` and ``m_abs`` looks like a typo
    for an arithmetic operator; confirm against the paper. In this class,
    ``compute_sed`` evaluates ``H_grid * 10 ** (-0.4 * W_grid)``, applies host
    dust extinction, and multiplies the result by ``Amplitude``.

    This class is based on the bayesian implementation at:
    https://github.com/bayesn/bayesn/blob/main/bayesn/bayesn_model.py

    Parameterized values include:
      * ra - The object's right ascension in degrees. [from BasePhysicalModel]
      * dec - The object's declination in degrees. [from BasePhysicalModel]
      * redshift - The object's redshift. [from BasePhysicalModel]
      * t0 - The t0 of the zero phase, date. [from BasePhysicalModel]
      * Amplitude - The fixed normalisation factor for distance modulus. [from Amplitude class]
      * theta - The bayeSN theta parameter.
      * Av - The bayeSN Av parameter.
      * Rv - The bayeSN Rv parameter.

    References
    ----------
    * BayeSN: Mandel S., 2020 - https://arxiv.org/pdf/2008.07538
    * M20 model: Mandel et al. 2022 (MNRAS 510, 3, 3939-3966)
    * T21 model: Thorp et al. 2021 (MNRAS 508, 3, 4310-4331) - currently unused
    * W22 model: Ward et al. 2023 (ApJ 956, 2, 111) - currently unused
    * Hsiao spectral template: Hsiao E. Y., 2009 - https://arxiv.org/abs/1503.02293

    Attributes
    ----------
    _hsiao_phase : numpy.ndarray
        The phase grid of the Hsiao template.
    _hsiao_wave : numpy.ndarray
        The wavelength grid of the Hsiao template.
    _hsiao_flux : numpy.ndarray
        The baseline mean intrinsic SED provided by the Hsiao template.
    _W0_ : numpy.ndarray
        The global W0 matrix.
    _W1_ : numpy.ndarray
        The global W1 matrix.
    _l_knots_ : numpy.ndarray
        The interpolation knots for wavelengths.
    _tau_knots_ : numpy.ndarray
        The interpolation knots for phase.

    Parameters
    ----------
    theta : parameter
        The bayeSN theta parameter.
    Av : parameter
        The bayeSN Av parameter.
        Set of host extinction values for each SN.
    Rv : parameter
        The bayeSN Rv parameter.
        Rv value for host extinction.
    t0 : parameter
        The t0 of the zero phase; forwarded to the parent constructor.
        Default: 0.0
    Amplitude : parameter
        The distance modulus (m - M)
        Normalized to have an absolute magnitude of -19.5.
        The computed flux is multiplied by this value.
        Default: 1.0
    M20_path_or_url : str
        The directory holding the M20 model files. A value starting with
        "http" is treated as a URL and the files are downloaded into the
        local BayeSN cache directory; any other value is used as a local path.
        Default: "https://github.com/bayesn/bayesn-model-files/raw/refs/heads/main/BAYESN.M20/"
    W0_filename : str
        The file name of the W0 matrix.
        Default: "W0.txt"
    W1_filename : str
        The file name of the W1 matrix.
        Default: "W1.txt"
    l_knots_filename : str
        The file name of the knot values of wavelengths for interpolation.
        Default: "l_knots.txt"
    tau_knots_filename : str
        The file name of the knot values of times for interpolation.
        Default: "tau_knots.txt"
    hsiao_path_or_url : str
        The Hsiao template HDF5 file. A value starting with "http" is treated
        as a URL and downloaded into the local BayeSN cache directory; any
        other value is used as a local path.
        Default: "https://github.com/bayesn/bayesn/raw/refs/heads/main/bayesn/data/hsiao.h5"
    **kwargs : dict, optional
        Any additional keyword arguments.
    """

    # The f_lambda unit (erg / s / cm^2 / Angstrom) stored once at class level
    # so it is not rebuilt on every call that converts flux units.
    _FLAM_UNIT = u.erg / u.second / u.cm**2 / u.AA
def __init__(
    self,
    theta=None,
    Av=None,
    Rv=None,
    t0=0.0,
    Amplitude=1.0,
    M20_path_or_url="https://github.com/bayesn/bayesn-model-files/raw/refs/heads/main/BAYESN.M20/",
    W0_filename="W0.txt",
    W1_filename="W1.txt",
    l_knots_filename="l_knots.txt",
    tau_knots_filename="tau_knots.txt",
    hsiao_path_or_url="https://github.com/bayesn/bayesn/raw/refs/heads/main/bayesn/data/hsiao.h5",
    **kwargs,
):
    """Register the model parameters and load the BayeSN data files.

    Parameters
    ----------
    theta, Av, Rv, Amplitude : parameter
        Model parameters registered through ``add_parameter``.
    t0 : parameter
        Forwarded to the parent constructor.
    M20_path_or_url : str or path-like
        Directory of the M20 model files. Values starting with "http" are
        downloaded into the local BayeSN cache; anything else is a local path.
    W0_filename, W1_filename, l_knots_filename, tau_knots_filename : str
        File names (inside the M20 directory) read with ``numpy.loadtxt``.
    hsiao_path_or_url : str or path-like
        The Hsiao template HDF5 file, as a URL (starting with "http") or a
        local path.
    **kwargs : dict, optional
        Forwarded to the parent constructor and to ``add_parameter``.

    Raises
    ------
    RuntimeError
        If a required data file cannot be downloaded.
    """
    super().__init__(t0=t0, **kwargs)

    # Define model specific parameters.
    self.add_parameter("theta", theta, **kwargs)
    self.add_parameter("Av", Av, **kwargs)
    self.add_parameter("Rv", Rv, **kwargs)
    self.add_parameter("Amplitude", Amplitude, **kwargs)

    # Resolve the M20 model directory, downloading its files when a URL is given.
    # str() lets callers pass pathlib.Path objects as well as strings.
    m20_source = str(M20_path_or_url)
    if m20_source.startswith("http"):
        m20_path = _LIGHTCURVELYNX_BAYESN_CACHE_PATH / "BAYESN.M20"
        # Strip trailing slashes so the joined URL never contains "//"
        # (the default URL ends with "/").
        base_url = m20_source.rstrip("/")
        for fname in (W0_filename, W1_filename, l_knots_filename, tau_knots_filename):
            file_url = f"{base_url}/{fname}"
            if not download_data_file_if_needed(m20_path / fname, file_url):
                raise RuntimeError(f"Failed to download BayeSN data file from {file_url}.")
    else:
        m20_path = Path(M20_path_or_url)

    self._W0_ = np.loadtxt(m20_path / W0_filename)
    self._W1_ = np.loadtxt(m20_path / W1_filename)
    self._l_knots_ = np.loadtxt(m20_path / l_knots_filename)
    self._tau_knots_ = np.loadtxt(m20_path / tau_knots_filename)

    # Resolve the Hsiao template file, downloading it when a URL is given.
    hsiao_source = str(hsiao_path_or_url)
    if hsiao_source.startswith("http"):
        hsiao_path = _LIGHTCURVELYNX_BAYESN_CACHE_PATH / "hsiao.h5"
        if not download_data_file_if_needed(hsiao_path, hsiao_source):
            raise RuntimeError(f"Failed to download BayeSN data file from {hsiao_path_or_url}.")
    else:
        hsiao_path = Path(hsiao_path_or_url)

    # The template lives in the "default" group; copy the datasets into
    # float64 arrays so the file can be closed right away.
    with h5py.File(hsiao_path, "r") as file:
        data = file["default"]
        self._hsiao_phase = data["phase"][()].astype("float64")
        self._hsiao_wave = data["wave"][()].astype("float64")
        self._hsiao_flux = data["flux"][()].astype("float64")
[docs]
def minphase(self, **kwargs):
    """Return the earliest rest-frame phase (in days) supported by the model.

    Parameters
    ----------
    **kwargs : dict
        Accepted for interface compatibility; ignored here.

    Returns
    -------
    minphase : float
        The lower phase bound of the model in days (always -20.0 for this model).
    """
    earliest_phase_days = -20.0
    return earliest_phase_days
[docs]
def maxphase(self, **kwargs):
    """Return the latest rest-frame phase (in days) supported by the model.

    Parameters
    ----------
    **kwargs : dict
        Accepted for interface compatibility; ignored here.

    Returns
    -------
    maxphase : float
        The upper phase bound of the model in days (always 85.0 for this model).
    """
    latest_phase_days = 85.0
    return latest_phase_days
# HELPER FUNCTIONS:
[docs]
def compute_invkd(self, x):
    """
    Build the operator K^{-1} D that maps knot values to the second
    derivatives of their natural cubic spline.

    Parameters
    ----------
    x : (n,) array_like
        Knot positions (non-uniform, strictly increasing).

    Returns
    -------
    invKD : (n, n) ndarray
        Matrix such that M = invKD @ y gives second derivatives of y.
        The first and last rows are zero (natural boundary conditions).
    """
    num_knots = len(x)
    interior = num_knots - 2

    # K: tridiagonal matrix of the spline equations for the interior knots.
    # D: second-difference operator acting on the knot values.
    K = np.zeros((interior, interior))
    D = np.zeros((interior, num_knots))

    gaps = np.diff(x)
    left = gaps[:-1]  # spacing to the previous knot, one entry per interior knot
    right = gaps[1:]  # spacing to the next knot, one entry per interior knot
    r = np.arange(interior)

    K[r, r] = (left + right) / 3  # main diagonal
    K[r[1:], r[1:] - 1] = left[1:] / 6  # subdiagonal
    K[r[:-1], r[:-1] + 1] = right[:-1] / 6  # superdiagonal

    D[r, r] = 1 / left
    D[r, r + 1] = -1 / left - 1 / right
    D[r, r + 2] = 1 / right

    # Solve K @ X = D for the interior rows; boundary rows stay zero.
    invKD = np.zeros((num_knots, num_knots))
    invKD[1:-1, :] = np.linalg.solve(K, D)
    return invKD
[docs]
def natural_cubic_spline_basis_matrix_from_invkd(self, x, xq_array, invKD):
    """
    Build the spline basis matrix J for a set of query points from a
    precomputed second-derivative operator.

    Parameters
    ----------
    x : (n,) array_like
        Knot positions.
    xq_array : (m,) array_like
        Query points.
    invKD : (n, n) ndarray
        Precomputed matrix such that second_derivatives = invKD @ y.

    Returns
    -------
    J : (m, n) ndarray
        Basis matrix. Row J[i, :] holds the weights on the knot values that
        give the spline value at xq_array[i]. Queries outside the knot range
        use the first or last interval.
    """
    knots = np.asarray(x)
    queries = np.asarray(xq_array)
    num_knots = len(knots)
    num_queries = len(queries)

    # Interval bracketing each query, clamped to the first/last interval.
    lower = np.clip(np.searchsorted(knots, queries) - 1, 0, num_knots - 2)
    upper = lower + 1

    # Linear weights on the two bracketing knots.
    width = knots[upper] - knots[lower]
    w_lower = (knots[upper] - queries) / width
    w_upper = 1 - w_lower

    # Weights on the second derivatives at the two bracketing knots.
    curv_lower = ((w_lower**3 - w_lower) * width**2) / 6
    curv_upper = ((w_upper**3 - w_upper) * width**2) / 6

    J = np.zeros((num_queries, num_knots))
    row_ids = np.arange(num_queries)
    J[row_ids, lower] = w_lower
    J[row_ids, upper] = w_upper
    J += curv_lower[:, None] * invKD[lower] + curv_upper[:, None] * invKD[upper]
    return J
[docs]
def compute_second_derivatives_1d(self, x, y):
    """
    Compute natural cubic spline second derivatives (M) for 1D input.

    Solves the tridiagonal system

        h[i-1] * M[i-1] + 2 * (h[i-1] + h[i]) * M[i] + h[i] * M[i+1]
            = 6 * ((y[i+1] - y[i]) / h[i] - (y[i] - y[i-1]) / h[i-1])

    for the interior knots with M[0] = M[-1] = 0, using the Thomas algorithm.

    Parameters
    ----------
    x : (n,) array_like
        Knot positions (strictly increasing).
    y : (n,) array_like
        Values at knots.

    Returns
    -------
    M : (n,) ndarray
        Second derivatives at knots. All zeros when n < 3.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(x)
    M = np.zeros(n)
    if n < 3:
        # No interior knots: the natural spline is a straight line.
        return M

    h = np.diff(x)
    slopes = np.diff(y) / h

    # Right-hand side. The factor is 6 (not 3): with 3 the system solves for
    # the quadratic polynomial coefficient, which is half the second derivative.
    rhs = np.zeros(n)
    rhs[1:-1] = 6.0 * np.diff(slopes)

    # Forward elimination (Thomas algorithm for the tridiagonal system).
    diag = np.ones(n)
    mu = np.zeros(n)
    z = np.zeros(n)
    for i in range(1, n - 1):
        diag[i] = 2 * (x[i + 1] - x[i - 1]) - h[i - 1] * mu[i - 1]
        mu[i] = h[i] / diag[i]
        z[i] = (rhs[i] - h[i - 1] * z[i - 1]) / diag[i]

    # Back substitution; the boundary entries stay zero.
    for j in range(n - 2, 0, -1):
        M[j] = z[j] - mu[j] * M[j + 1]
    return M
[docs]
def compute_2d_second_derivatives(self, x, y, z):
    """
    Compute 1D natural-spline second derivatives of a gridded function
    along each grid direction.

    Parameters
    ----------
    x : (n,) array_like
        Knot positions in x.
    y : (m,) array_like
        Knot positions in y.
    z : (n, m) ndarray
        Function values on 2D grid.

    Returns
    -------
    Mx : (n, m) ndarray
        Second derivatives with respect to x (one spline per column of z).
    My : (n, m) ndarray
        Second derivatives with respect to y (one spline per row of z).
    """
    num_x, num_y = len(x), len(y)
    second_1d = self.compute_second_derivatives_1d

    # Splines along x: treat every column of z as a function of x.
    Mx = np.zeros((num_x, num_y))
    for col in range(num_y):
        Mx[:, col] = second_1d(x, z[:, col])

    # Splines along y: treat every row of z as a function of y.
    My = np.zeros((num_x, num_y))
    for row in range(num_x):
        My[row, :] = second_1d(y, z[row, :])

    return Mx, My
[docs]
def evaluate_natural_spline_2d_vectorized(self, x, y, z, Mx, My, xq, yq):
    """
    Vectorized 2D natural cubic spline evaluation using second derivatives.

    For every (xq[i], yq[j]) pair the surface is first interpolated along x
    on the two grid columns that bracket yq[j], and the two results are then
    combined along y. Each 1D step uses the standard cubic spline formula

        S = (a**3 * M0 + b**3 * M1) * h**2 / 6
            + a * (f0 - M0 * h**2 / 6) + b * (f1 - M1 * h**2 / 6)

    with a = (x1 - xq) / h and b = 1 - a, so the result reproduces the knot
    values for any knot spacing. The y-direction second derivatives are read
    from the grid row at the lower bracketing x index (they are not
    interpolated in x). Queries outside the grid use the first or last interval.

    Parameters
    ----------
    x : (n,) array_like
        Knot positions in x.
    y : (m,) array_like
        Knot positions in y.
    z : (n, m) array_like
        Function values on 2D grid.
    Mx : (n, m) array_like
        Second derivatives w.r.t. x.
    My : (n, m) array_like
        Second derivatives w.r.t. y.
    xq : (p,) array_like
        Query positions in x.
    yq : (q,) array_like
        Query positions in y.

    Returns
    -------
    Zq : (p, q) ndarray
        Interpolated 2D surface values.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    z = np.asarray(z)
    Mx = np.asarray(Mx)
    My = np.asarray(My)
    xq = np.atleast_1d(xq)
    yq = np.atleast_1d(yq)

    # Bracketing interval of every query, clamped to the first/last interval.
    ix0 = np.clip(np.searchsorted(x, xq) - 1, 0, len(x) - 2)
    iy0 = np.clip(np.searchsorted(y, yq) - 1, 0, len(y) - 2)
    ix1 = ix0 + 1
    iy1 = iy0 + 1

    # x quantities are (p, 1) columns and y quantities are (1, q) rows, so
    # every expression below broadcasts to the (p, q) output grid.
    hx = x[ix1] - x[ix0]
    ax = ((x[ix1] - xq) / hx)[:, None]
    bx = 1 - ax
    kx = (hx**2 / 6)[:, None]

    hy = y[iy1] - y[iy0]
    ay = ((y[iy1] - yq) / hy)[None, :]
    by = 1 - ay
    ky = (hy**2 / 6)[None, :]

    rows0 = ix0[:, None]
    rows1 = ix1[:, None]

    def _interp_along_x(col_idx):
        # Cubic interpolation in x on the grid columns given by col_idx.
        cols = col_idx[None, :]
        z0, z1 = z[rows0, cols], z[rows1, cols]
        m0, m1 = Mx[rows0, cols], Mx[rows1, cols]
        return (ax**3 * m0 + bx**3 * m1) * kx + ax * (z0 - m0 * kx) + bx * (z1 - m1 * kx)

    f_y0 = _interp_along_x(iy0)  # values along x on the lower y column
    f_y1 = _interp_along_x(iy1)  # values along x on the upper y column

    # Combine the two columns with a cubic step along y.
    My0 = My[rows0, iy0[None, :]]
    My1 = My[rows0, iy1[None, :]]
    return (ay**3 * My0 + by**3 * My1) * ky + ay * (f_y0 - My0 * ky) + by * (f_y1 - My1 * ky)
[docs]
def evaluate_2d_cubic_spline(self, x, y, z, xq, yq):
    """
    Evaluate a 2D natural cubic spline on a query grid, computing the
    required second derivatives internally.

    Parameters
    ----------
    x : (n,) array_like
        Knot positions in x direction.
    y : (m,) array_like
        Knot positions in y direction.
    z : (n, m) array_like
        Function values at grid points.
    xq : (p,) array_like
        Query points in x.
    yq : (q,) array_like
        Query points in y.

    Returns
    -------
    Zq : (p, q) ndarray
        Interpolated values at query grid.
    """
    # Second derivatives along both grid directions feed the evaluator directly.
    second_derivs = self.compute_2d_second_derivatives(x, y, z)
    return self.evaluate_natural_spline_2d_vectorized(x, y, z, *second_derivs, xq, yq)
# MAIN FUNCTION:
[docs]
def compute_sed(self, times, wavelengths, graph_state, **kwargs):
    """Draw effect-free observations for this object.

    The SED is the Hsiao template flux modulated by the BayeSN warping
    surface W = W0 + theta * W1, then reddened by host dust (F99 law),
    scaled by ``Amplitude`` and converted from f_lambda to f_nu. Rows for
    times whose phase lies outside [minphase, maxphase] are zero.

    Parameters
    ----------
    times : numpy.ndarray
        A length T array of rest frame timestamps in MJD.
    wavelengths : numpy.ndarray
        A length N array of rest frame wavelengths (in angstroms).
    graph_state : GraphState
        An object mapping graph parameters to their values.
    **kwargs : dict, optional
        Any additional keyword arguments.

    Returns
    -------
    flux_density : numpy.ndarray
        A length T x N matrix of SED values (in nJy).
    """
    params = self.get_local_params(graph_state)

    # Phase relative to t0 and the subset of times the model supports.
    tau = times - params["t0"]
    within_phase_range = (tau >= self.minphase()) & (tau <= self.maxphase())
    valid_tau = tau[within_phase_range]

    # Spline basis matrices in wavelength (N x n_l) and phase (T' x n_tau).
    basis_wave = self.natural_cubic_spline_basis_matrix_from_invkd(
        self._l_knots_, wavelengths, self.compute_invkd(self._l_knots_)
    )
    basis_phase = self.natural_cubic_spline_basis_matrix_from_invkd(
        self._tau_knots_, valid_tau, self.compute_invkd(self._tau_knots_)
    )

    # Warping surface on the knots, then evaluated on the (phase, wavelength) grid.
    knot_surface = self._W0_ + params["theta"] * self._W1_
    warp_by_phase = np.matmul(knot_surface, np.atleast_2d(basis_phase).T)
    W_grid = np.atleast_2d(np.matmul(basis_wave, warp_by_phase)).T

    # Baseline template flux on the same grid.
    H_grid = self.evaluate_2d_cubic_spline(
        self._hsiao_phase, self._hsiao_wave, self._hsiao_flux, valid_tau, wavelengths
    )

    flux_density = np.zeros((len(times), len(wavelengths)))
    flux_density[within_phase_range, :] = H_grid * 10 ** (-0.4 * W_grid)

    # Host dust extinction, with E(B-V) = Av / Rv.
    host_rv = params["Rv"]
    ebv = params["Av"] / host_rv
    ext = ExtinctionEffect(
        extinction_model="F99",
        ebv=ebv,
        frame="rest",
        r_v=host_rv,
        backend="dust_extinction",
    )
    flux_density = ext.apply(flux_density, tau, wavelengths, ebv)

    # Fixed distance modulus normalisation factor.
    flux_density = flux_density * params["Amplitude"]

    # Convert f_lambda to f_nu in nJy.
    return flam_to_fnu(
        flux_density,
        wavelengths,
        wave_unit=u.AA,
        flam_unit=self._FLAM_UNIT,
        fnu_unit=u.nJy,
    )