import importlib
import logging
from pathlib import Path
import numpy as np
from astropy import units as u
from lightcurvelynx.astro_utils.passbands import PassbandGroup
from lightcurvelynx.astro_utils.snia_utils import (
DistModFromRedshift,
HostmassX1Func,
X0FromDistMod,
num_snia_per_redshift_bin,
)
from lightcurvelynx.astro_utils.unit_utils import flam_to_fnu, fnu_to_flam
from lightcurvelynx.math_nodes.np_random import NumpyRandomFunc
from lightcurvelynx.math_nodes.scipy_random import SamplePDF
from lightcurvelynx.models.sncosmo_models import SncosmoWrapperModel
from lightcurvelynx.models.snia_host import SNIaHost
from scipy.interpolate import interp1d
[docs]
# Module-level logger; the functions below report progress and warnings through it.
logger = logging.getLogger(__name__)
[docs]
def construct_snia_source(oversampled_observations, zpdf):
    """Build a SALT3 SN Ia source model tied to a randomly drawn host galaxy.

    Parameters
    ----------
    oversampled_observations : OpSim
        The opsim data to use. Only its "time" column is read here, to bound
        the interval from which the explosion time t0 is drawn.
    zpdf : interp1d
        The PDF for the redshift; the host redshift is drawn from it via SamplePDF.

    Returns
    -------
    source : SncosmoWrapperModel
        The "salt3" model node (labeled "source") to sample. Its redshift is
        shared with the host, and its x1/c parameters feed both the model and
        the x0 computation.
    """
    logger.info("Creating the source model.")

    # t0 is drawn uniformly between the first and last observation times.
    obs_times = oversampled_observations["time"]
    earliest_time, latest_time = obs_times.min(), obs_times.max()

    # TODO: Extract the ra and dec center from the opsim in case that changes.
    # For now this relies on every pointing of the test opsim being at (0.0, 0.0),
    # which is why ra/dec are drawn uniformly from [-0.5, 0.5].
    host_galaxy = SNIaHost(
        ra=NumpyRandomFunc("uniform", low=-0.5, high=0.5),
        dec=NumpyRandomFunc("uniform", low=-0.5, high=0.5),
        hostmass=NumpyRandomFunc("uniform", low=7, high=12),
        redshift=SamplePDF(zpdf),
        node_label="host",
    )

    # SALT parameter chain: distance modulus from the host redshift, stretch (x1)
    # from the host mass, a Gaussian color (c) and absolute magnitude, all combined
    # into the amplitude x0.
    # NOTE(review): H0=73.0 here differs from the H0=70.0 that run_snia_end2end uses
    # for the redshift distribution -- confirm this is intended.
    distance_modulus = DistModFromRedshift(host_galaxy.redshift, H0=73.0, Omega_m=0.3)
    stretch = HostmassX1Func(host_galaxy.hostmass)
    color = NumpyRandomFunc("normal", loc=0, scale=0.02)
    absolute_mag = NumpyRandomFunc("normal", loc=-19.3, scale=0.1)
    amplitude = X0FromDistMod(
        distmod=distance_modulus,
        x1=stretch,
        c=color,
        alpha=0.14,
        beta=3.1,
        m_abs=absolute_mag,
        node_label="x0_func",
    )

    # The SN position is scattered slightly (sigma = 0.01) around its host.
    return SncosmoWrapperModel(
        "salt3",
        t0=NumpyRandomFunc("uniform", low=earliest_time, high=latest_time),
        x0=amplitude,
        x1=stretch,
        c=color,
        ra=NumpyRandomFunc("normal", loc=host_galaxy.ra, scale=0.01),
        dec=NumpyRandomFunc("normal", loc=host_galaxy.dec, scale=0.01),
        redshift=host_galaxy.redshift,
        node_label="source",
    )
[docs]
def load_and_register_passband(passbands_dir, to_use):
    """Load LSST passbands from files and register each one with sncosmo.

    Each requested band ``b`` is read from ``<passbands_dir>/LSST/<b>.dat``. Every
    passband in the resulting group is then registered with sncosmo (with
    ``force=True``, replacing any previous registration) under the name
    ``lightcurvelynx_<key>``, where ``<key>`` is the passband's key in
    ``PassbandGroup.passbands``.

    Parameters
    ----------
    passbands_dir : str or pathlib.Path
        The directory containing the passband files to use. The files are
        looked up in its ``LSST`` subdirectory.
    to_use : list of str
        A list of the passbands to use.
        Example: ["g", "r"]

    Returns
    -------
    passbands : PassbandGroup
        The loaded and processed PassbandGroup.

    Raises
    ------
    ImportError
        If the sncosmo package cannot be imported.
    ValueError
        If ``to_use`` is empty.
    """
    # EAFP import: ``import importlib`` alone does not guarantee that the
    # ``importlib.util`` submodule is loaded, so probing via find_spec is fragile.
    try:
        import sncosmo
    except ImportError as err:  # pragma: no cover
        raise ImportError("The sncosmo package is required to use the run_snia_end2end. ") from err

    passbands_dir = Path(passbands_dir)
    passband_list = []
    for band in to_use:
        file_path = passbands_dir / "LSST" / f"{band}.dat"
        passband_list.append({"filter_name": band, "table_path": file_path})
        logger.info("Loading band %s from %s", band, file_path)
    if not passband_list:
        raise ValueError("No passbands being loaded.")

    # Do the actual loading and processing (the files are given in nm).
    passbands = PassbandGroup(
        passband_list,
        survey="LSST",
        units="nm",
        trim_quantile=0.001,
        delta_wave=1,
    )

    # Register each passband with sncosmo so sncosmo bandflux calls can use it.
    # NOTE(review): the transpose is unpacked as positional Bandpass arguments,
    # which assumes normalized_system_response is (N, 2) = (wavelength, transmission).
    for filter_key, passband in passbands.passbands.items():
        sncosmo_bandpass = sncosmo.Bandpass(
            *passband.normalized_system_response.T, name=f"lightcurvelynx_{filter_key}"
        )
        sncosmo.register(sncosmo_bandpass, force=True)
    return passbands
[docs]
def draw_single_random_sn(
    source,
    opsim,
    passbands,
    state=None,
    rng_info=None,
):
    """Simulate the observations of a single random SN realization.

    Parameters
    ----------
    source : BasePhysicalModel
        The BasePhysicalModel to use for the flux computation. Its state must
        expose the "source" parameters ra, dec, t0 and redshift.
    opsim : OpSim
        The OpSim for the simulations. Its "time" and "filter" columns,
        ``range_search`` and ``noise_model`` are used.
    passbands : PassbandGroup
        The passbands to use in generating the observations.
    state : GraphState, optional
        The sample values to use. If None, a new state is sampled from ``source``.
    rng_info : numpy.random._generator.Generator, optional
        A given numpy random number generator to use for this computation. If not
        provided, the function uses the default random number generator.

    Returns
    -------
    res : dict
        A dictionary of useful information about the run, with the keys
        "wavelengths_rest", "times", "filters", "flux_nJy", "flux_flam",
        "flux_fnu" (same array as "flux_nJy"), "bandfluxes_perfect",
        "bandfluxes_error", "bandfluxes" (noisy values) and "state".
    """
    if state is None:
        state = source.sample_parameters(rng_info=rng_info)

    # Extract the parameters that determine where and when the SN is seen.
    ra = state["source"]["ra"]
    dec = state["source"]["dec"]
    t0 = state["source"]["t0"]
    z = state["source"]["redshift"]

    # Compute the rest-frame wavelengths and save them.
    wave_obs = passbands.waves
    wavelengths_rest = wave_obs / (1.0 + z)
    res = {"wavelengths_rest": wavelengths_rest}

    # Find the observations whose pointings lie near this source. Force an integer
    # dtype: an empty search result would otherwise become a float64 array, which
    # is not a valid positional index.
    obs_index = np.asarray(opsim.range_search(ra, dec, radius=1.75), dtype=int)

    # Keep only observations within the SN lifespan (-20 to +50 relative to t0).
    # The window is stretched by (1 + z), presumably for cosmological time dilation.
    candidate_times = opsim["time"].iloc[obs_index].to_numpy()
    phase_obs = candidate_times - t0
    in_lifespan = (phase_obs > -20 * (1.0 + z)) & (phase_obs < 50 * (1.0 + z))
    obs_index = obs_index[in_lifespan]
    times = candidate_times[in_lifespan]
    if len(times) == 0:
        logger.warning(f"No overlap time in opsim for (ra,dec)=({ra:.2f},{dec:.2f})")
    res["times"] = times

    # Prefix the opsim filter names with "LSST_", presumably to match the full
    # passband names in the PassbandGroup (loaded with survey="LSST").
    filters = opsim["filter"].iloc[obs_index].to_numpy(str)
    filters = np.char.add("LSST_", filters)
    res["filters"] = filters

    # Compute the fluxes over all wavelengths.
    flux_nJy = source.evaluate_sed(times, wave_obs, graph_state=state, rng_info=rng_info)
    res["flux_nJy"] = flux_nJy
    res["flux_flam"] = fnu_to_flam(
        flux_nJy,
        wave_obs,
        wave_unit=u.AA,
        flam_unit=u.erg / u.second / u.cm**2 / u.AA,
        fnu_unit=u.nJy,
    )
    res["flux_fnu"] = flux_nJy

    # Compute the noiseless band fluxes for just the observed filters, then add noise.
    bandfluxes_perfect = source.evaluate_bandfluxes(passbands, times, filters, state)
    res["bandfluxes_perfect"] = bandfluxes_perfect
    new_vals, err_vals = opsim.noise_model.apply_noise(
        bandfluxes_perfect,
        obs_table=opsim,
        indices=obs_index,
        rng=rng_info,
    )
    res["bandfluxes_error"] = err_vals
    res["bandfluxes"] = new_vals
    res["state"] = state
    return res
[docs]
def _extract_parameter_values(state):
    """Copy the key SN and host parameters from a single-sample state into a dict of floats."""
    params = {
        name: float(state["source"][name])
        for name in ["t0", "x0", "x1", "c", "redshift", "ra", "dec"]
    }
    params["hostmass"] = float(state["host.hostmass"])
    params["distmod"] = float(state["x0_func.distmod"])
    return params


def _compare_with_sncosmo(res, params, passbands, sncosmo):
    """Re-simulate one realization directly with sncosmo and assert the results agree."""
    saltpars = {
        "x0": params["x0"],
        "x1": params["x1"],
        "c": params["c"],
        "z": params["redshift"],
        "t0": params["t0"],
    }
    model = sncosmo.Model("salt3")
    model.update(saltpars)

    wave = passbands.waves
    time = res["times"]
    filters = res["filters"]

    # sncosmo returns f_lambda; convert it to f_nu in nJy to compare with our SED.
    flux_sncosmo = model.flux(time, wave)
    fnu_sncosmo = flam_to_fnu(
        flux_sncosmo,
        wave,
        wave_unit=u.AA,
        flam_unit=u.erg / u.second / u.cm**2 / u.AA,
        fnu_unit=u.nJy,
    )
    np.testing.assert_allclose(res["flux_nJy"], fnu_sncosmo, atol=1e-8, rtol=1e-6)
    np.testing.assert_allclose(res["flux_flam"], flux_sncosmo, atol=1e-30, rtol=1e-5)

    # Skip the bandflux comparison when sncosmo produces any non-positive flux.
    if np.all(flux_sncosmo > 0):
        sncosmo_band_names = np.char.add("lightcurvelynx_", filters)
        # zp = 8.9 + 2.5 * 9, presumably the AB zero point for fluxes in nJy
        # (8.9 for Jy plus 2.5 * log10(1e9)).
        bandflux_sncosmo = model.bandflux(sncosmo_band_names, time, zpsys="ab", zp=8.9 + 2.5 * 9)
        np.testing.assert_allclose(res["bandfluxes_perfect"], bandflux_sncosmo, rtol=1e-1)


def run_snia_end2end(
    oversampled_observations,
    passbands_dir,
    solid_angle=0.0001,
    nsample=1,
    check_sncosmo=False,
    rng_info=None,
):
    """Sample and simulate SN Ia light curves using the salt3 model.

    Parameters
    ----------
    oversampled_observations : OpSim
        The opsim data to use.
    passbands_dir : str
        The name of the directory holding the passband information.
    solid_angle : float
        Solid angle for calculating the number of SN. Only used when
        ``nsample`` is None.
    nsample : int or None
        The number of samples to draw. If None, it is computed from the SN Ia
        rate over ``solid_angle``.
        Default: 1
    check_sncosmo : bool
        Run the simulation a second time directly with sncosmo and compare the answers.
        This should only be turned on for testing.
        Default: False
    rng_info : numpy.random._generator.Generator, optional
        A given numpy random number generator to use for this computation, both
        for sampling the parameters and for the noise. If not provided, a new
        default generator is created.

    Returns
    -------
    res_list : list of dict
        One dictionary of sampling and result information per sample (see
        ``draw_single_random_sn``), each with an added "parameter_values" dict.
    passbands : PassbandGroup
        The passbands used.

    Raises
    ------
    ImportError
        If the sncosmo package cannot be imported.
    ValueError
        If both ``nsample`` and ``solid_angle`` are None.
    """
    # EAFP import: ``import importlib`` alone does not guarantee that the
    # ``importlib.util`` submodule is loaded, so probing via find_spec is fragile.
    try:
        import sncosmo
    except ImportError as err:  # pragma: no cover
        raise ImportError("The sncosmo package is required to use the run_snia_end2end. ") from err

    if rng_info is None:
        rng_info = np.random.default_rng()

    # Compute the distribution from which to sample the redshift.
    zmin = 0.1
    zmax = 0.4
    H0 = 70.0
    Omega_m = 0.3
    nsn, z = num_snia_per_redshift_bin(zmin, zmax, 100, H0=H0, Omega_m=Omega_m)
    zpdf = interp1d(z, nsn, bounds_error=False, fill_value=0)

    # Calculate nsample using the SN Ia rate model.
    if nsample is None and solid_angle is not None:
        nsn, _ = num_snia_per_redshift_bin(
            zmin, zmax, znbins=1, solid_angle=solid_angle, H0=H0, Omega_m=Omega_m
        )
        nsample = int(nsn[0])  # Since znbins=1 there is only one bin
        logger.info(f"Drawing {nsample} samples from redshift {zmin} to {zmax}.")
    if nsample is None:
        raise ValueError("Either nsample or solid_angle must be provided.")

    source = construct_snia_source(oversampled_observations, zpdf)
    passbands = load_and_register_passband(passbands_dir, to_use=["g", "r"])

    logger.info("Sampling %s states.", nsample)
    # Pass rng_info so that a seeded generator makes the sampled parameters reproducible.
    sample_states = source.sample_parameters(num_samples=nsample, rng_info=rng_info)

    res_list = []
    for i in range(nsample):
        current_state = sample_states.extract_single_sample(i)
        res = draw_single_random_sn(
            source,
            opsim=oversampled_observations,
            passbands=passbands,
            state=current_state,
            rng_info=rng_info,
        )
        params = _extract_parameter_values(current_state)
        res["parameter_values"] = params

        if check_sncosmo:
            _compare_with_sncosmo(res, params, passbands, sncosmo)
        res_list.append(res)
    return res_list, passbands