"""Wrappers for the models defined in sncosmo.
https://github.com/sncosmo/sncosmo/blob/v2.10.1/sncosmo/models.py
https://sncosmo.readthedocs.io/en/stable/models.html
"""
from astropy import units as u
from citation_compass import CiteClass
from lightcurvelynx.astro_utils.unit_utils import flam_to_fnu
from lightcurvelynx.models.physical_model import SEDModel
class SncosmoWrapperModel(SEDModel, CiteClass):
    """A wrapper for sncosmo models.

    Parameterized values include:
      * dec - The object's declination in degrees. [from BasePhysicalModel]
      * distance - The object's luminosity distance in pc. [from BasePhysicalModel]
      * ra - The object's right ascension in degrees. [from BasePhysicalModel]
      * redshift - The object's redshift. [from BasePhysicalModel]
      * t0 - The t0 of the zero phase, date. [from BasePhysicalModel]

    Additional parameterized values are used for specific sncosmo models.

    References
    ----------
    * sncosmo - https://zenodo.org/records/14714968
    * Individual models might require citation. See references in the sncosmo documentation.

    Attributes
    ----------
    source : sncosmo.Source
        The underlying source model.
    source_name : str
        The name used to set the source.
    source_param_names : list
        A list of the source model's parameters that we need to set.

    Parameters
    ----------
    source_name : str
        The name used to set the source.
    node_label : str, optional
        An identifier (or name) for the current node.
    wave_extrapolation : FluxExtrapolationModel or tuple, optional
        The extrapolation model(s) to use for wavelengths that fall outside the model's defined
        bounds. If a tuple is provided, then it is expected to be of the form (before_model, after_model)
        where before_model is the model for before the first valid wavelength and after_model is
        the model for after the last valid wavelength. If None is provided the model will not try to
        extrapolate, but rather call compute_sed() for all wavelengths (which raises a ValueError
        for wavelengths outside the sncosmo model's valid range).
    time_extrapolation : FluxExtrapolationModel or tuple, optional
        The extrapolation model(s) to use for times that fall outside the model's defined
        bounds. If a tuple is provided, then it is expected to be of the form (before_model, after_model)
        where before_model is the model for before the first valid time and after_model is
        the model for after the last valid time. If None is provided the model will not try to
        extrapolate, but rather call compute_sed() for all times.
    seed : int, optional
        The seed for a random number generator.
    **kwargs : dict, optional
        Any additional keyword arguments. They are passed to the parent constructor, and any
        key not already registered as a parameter is added as a new model parameter. Keys that
        are also sncosmo source parameters are forwarded to the sncosmo source when sampling.
    """

    # A class variable for the units so we are not computing them each time.
    _FLAM_UNIT = u.erg / u.second / u.cm**2 / u.AA

    def __init__(
        self,
        source_name,
        node_label=None,
        wave_extrapolation=None,
        time_extrapolation=None,
        seed=None,
        **kwargs,
    ):
        # sncosmo is an optional dependency, so it is only imported when this model is built.
        try:
            from sncosmo.models import get_source
        except ImportError as err:  # pragma: no cover
            raise ImportError(
                "sncosmo package is not installed by default. To use the SncosmoWrapperModel, "
                "please install sncosmo. For example, you can install it with "
                "`pip install sncosmo` or `conda install conda-forge::sncosmo`."
            ) from err

        # We explicitly ask for and pass along the PhysicalModel parameters such
        # as node_label and wave_extrapolation so they do not go into kwargs
        # and get added to the sncosmo model below.
        super().__init__(
            node_label=node_label,
            wave_extrapolation=wave_extrapolation,
            time_extrapolation=time_extrapolation,
            seed=seed,
            **kwargs,
        )

        self.source_name = source_name
        self.source = get_source(source_name)

        # Use the kwargs to initialize the sncosmo model's parameters. Keys that are
        # already registered in self.setters (e.g. by the parent constructor) are not
        # added a second time; only genuine sncosmo parameters are forwarded to the source.
        self.source_param_names = []
        for key, value in kwargs.items():
            if key not in self.setters:
                self.add_parameter(key, value, description="Parameter for sncosmo model.")
            if key in self.source.param_names:
                self.source_param_names.append(key)

    @property
    def param_names(self):
        """Return the wrapped sncosmo source's parameter names."""
        return self.source.param_names

    @property
    def parameter_values(self):
        """Return the wrapped sncosmo source's current parameter values."""
        return self.source.parameters

    def minphase(self, **kwargs):
        """Get the minimum phase of the model (in days relative to t0).

        Parameters
        ----------
        **kwargs : dict
            Additional keyword arguments, not used in this method.

        Returns
        -------
        minphase : float or None
            The minimum phase of the model (in days relative to t0) or None
            if the model does not have a defined minimum phase.
        """
        return self.source.minphase()

    def maxphase(self, **kwargs):
        """Get the maximum phase of the model (in days relative to t0).

        Parameters
        ----------
        **kwargs : dict
            Additional keyword arguments, not used in this method.

        Returns
        -------
        maxphase : float or None
            The maximum phase of the model (in days relative to t0) or None
            if the model does not have a defined maximum phase.
        """
        return self.source.maxphase()

    def minwave(self, **kwargs):
        """Get the minimum wavelength of the model.

        Parameters
        ----------
        **kwargs : dict
            Additional keyword arguments, not used in this method.

        Returns
        -------
        minwave : float or None
            The minimum wavelength of the model (in angstroms) or None
            if the model does not have a defined minimum wavelength.
        """
        return self.source.minwave()

    def maxwave(self, **kwargs):
        """Get the maximum wavelength of the model.

        Parameters
        ----------
        **kwargs : dict
            Additional keyword arguments, not used in this method.

        Returns
        -------
        maxwave : float or None
            The maximum wavelength of the model (in angstroms) or None
            if the model does not have a defined maximum wavelength.
        """
        return self.source.maxwave()

    def _update_sncosmo_model_parameters(self, graph_state):
        """Copy this node's sampled sncosmo parameter values into the wrapped source.

        Parameters
        ----------
        graph_state : GraphState
            An object mapping graph parameters to their values.
        """
        # NOTE(review): only sample index 0 of this node's state is read, so this assumes
        # graph_state holds a single sample (or that the first sample is the one wanted).
        local_params = graph_state.get_node_state(self.node_string, 0)
        sn_params = {name: local_params[name] for name in self.source_param_names}
        self.source.set(**sn_params)

    def get(self, name):
        """Get the value of a specific parameter.

        Parameters
        ----------
        name : str
            The name of the parameter.

        Returns
        -------
        The parameter value.
        """
        return self.source.get(name)

    def set(self, **kwargs):
        """Set the parameters of the model.

        These must all be constants to be compatible with sncosmo. Names that are not
        yet registered on this node are added as new model parameters.

        Parameters
        ----------
        **kwargs : dict
            The parameters to set and their values. Every key must be a parameter
            of the wrapped sncosmo source.

        Raises
        ------
        KeyError
            If any key is not a parameter of the wrapped sncosmo source. This is checked
            before any state is modified.
        """
        # Validate every key before touching any state. Otherwise an unknown name would be
        # left behind in source_param_names and then passed to source.set() by
        # _update_sncosmo_model_parameters() on every subsequent update.
        unknown = [key for key in kwargs if key not in self.source.param_names]
        if unknown:
            raise KeyError(f"Unknown parameter: {unknown[0]!r}")

        for key, value in kwargs.items():
            # Use the same "already registered" check as __init__ (self.setters) rather
            # than hasattr(), which is also true for arbitrary methods and attributes.
            if key in self.setters:
                self.set_parameter(key, value)
            else:
                self.add_parameter(key, value, description="Parameter for sncosmo model.")
            if key not in self.source_param_names:
                self.source_param_names.append(key)
        self.source.set(**kwargs)

    def _sample_helper(self, graph_state, seen_nodes, rng_info=None):
        """Internal recursive function to sample the model's underlying parameters
        if they are provided by a function or ParameterizedNode.

        Calls the parent class's _sample_helper() and then updates the parameters
        of the wrapped sncosmo model from the sampled values.

        Parameters
        ----------
        graph_state : GraphState
            An object mapping graph parameters to their values. This object is modified
            in place as it is sampled.
        seen_nodes : dict
            A dictionary mapping nodes seen during this sampling run to their ID.
            Used to avoid sampling nodes multiple times and to validity check the graph.
        rng_info : numpy.random._generator.Generator, optional
            A given numpy random number generator to use for this computation. If not
            provided, the function uses the node's random number generator.

        Raises
        ------
        ValueError
            If the sampling encounters a problem with the order of dependencies.
        """
        super()._sample_helper(graph_state, seen_nodes, rng_info=rng_info)
        self._update_sncosmo_model_parameters(graph_state)

    def compute_sed(self, times, wavelengths, graph_state=None, **kwargs):
        """Draw effect-free observations for this object.

        Parameters
        ----------
        times : numpy.ndarray
            A length T array of rest frame timestamps.
        wavelengths : numpy.ndarray
            A length N array of wavelengths (in angstroms).
        graph_state : GraphState
            An object mapping graph parameters to their values. Although it defaults to
            None, a GraphState is required: it is read to update the sncosmo parameters.
        **kwargs : dict, optional
            Any additional keyword arguments.

        Returns
        -------
        flux_density : numpy.ndarray
            A length T x N matrix of SED values (in nJy).

        Raises
        ------
        ValueError
            If any wavelength falls outside the sncosmo model's valid wavelength range.
        """
        params = self.get_local_params(graph_state)
        self._update_sncosmo_model_parameters(graph_state)

        # Provide debugging information about the wavelength range if the wavelengths are out of bounds.
        model_min_wave = self.minwave()
        model_max_wave = self.maxwave()
        query_min_wave = wavelengths.min()
        query_max_wave = wavelengths.max()
        if query_min_wave < model_min_wave or query_max_wave > model_max_wave:
            raise ValueError(
                "Wavelengths are out of the sncosmo model's valid range. "
                f"Model wavelength range: [{model_min_wave}, {model_max_wave}], "
                f"Query wavelength range: [{query_min_wave}, {query_max_wave}]. Use the "
                "'wave_extrapolation' parameter to specify how to handle "
                "out-of-bounds wavelengths."
            )

        # Query the model (sncosmo works in phase relative to t0) and convert the
        # output from f_lambda to f_nu in nJy.
        phase = times - params["t0"]
        model_flam = self.source.flux(phase, wavelengths)
        model_fnu = flam_to_fnu(
            model_flam,
            wavelengths,
            wave_unit=u.AA,
            flam_unit=self._FLAM_UNIT,
            fnu_unit=u.nJy,
        )
        return model_fnu