"""Models that generate a constant SED or bandflux at all times."""
import numpy as np
from astropy import units as u
from lightcurvelynx.astro_utils.sed import SED
from lightcurvelynx.math_nodes.given_sampler import GivenValueSampler
from lightcurvelynx.models.physical_model import BandfluxModel, SEDModel
class StaticSEDModel(SEDModel):
    """A StaticSEDModel randomly selects an SED at each evaluation and computes
    the flux from that SED at all time steps.

    Parameterized values include:
      * dec - The object's declination in degrees. [from BasePhysicalModel]
      * distance - The object's luminosity distance in pc. [from BasePhysicalModel]
      * ra - The object's right ascension in degrees. [from BasePhysicalModel]
      * redshift - The object's redshift. [from BasePhysicalModel]
      * t0 - The t0 of the zero phase, date. [from BasePhysicalModel. Not used.]
      * selected_idx - The index of the SED selected for a given simulation.

    Attributes
    ----------
    sed_values : list of SED
        The SEDs from which to sample. Any two row numpy arrays passed to the
        constructor are converted into SED objects.

    Parameters
    ----------
    sed_values : SED, numpy.ndarray, or list/tuple/numpy.ndarray of SED or numpy.ndarray
        A single SED or a collection of SEDs from which to sample. Each SED is
        either an instance of the SED class or a two row numpy array where the
        first row is wavelength and the second row is flux. The given container
        is copied; it is never modified.
    weights : numpy.ndarray, optional
        A length N array indicating the relative weight from which to select
        an SED at random. If None, all SEDs will be weighted equally.
    **kwargs : dict
        Additional keyword arguments passed to the SEDModel constructor.

    Raises
    ------
    ValueError
        If sed_values is of an unsupported type, is empty, or contains an entry
        that is neither an SED nor a valid two row numpy array.
    """

    def __init__(
        self,
        sed_values,
        weights=None,
        **kwargs,
    ):
        # Normalize the input into a *new* list. The conversion loop below writes
        # SED objects back by index, so storing the caller's list would mutate it,
        # and storing a numpy array would fail (a float array cannot hold SEDs).
        if isinstance(sed_values, SED):
            self.sed_values = [sed_values]
        elif isinstance(sed_values, np.ndarray) and sed_values.ndim == 2 and sed_values.shape[0] == 2:
            # A single two row (wavelength, flux) array. Checking ndim keeps a stack
            # of exactly two SEDs (ndim == 3) or a 1-D object array of two SEDs from
            # being mistaken for a single SED.
            self.sed_values = [sed_values]
        elif isinstance(sed_values, list | tuple | np.ndarray):
            self.sed_values = list(sed_values)
        else:
            raise ValueError("sed_values must be a single SED, a two row numpy array, or a list of SEDs.")

        if len(self.sed_values) == 0:
            raise ValueError("sed_values must contain at least one SED.")

        # Validate the SED input data, converting numpy arrays into SED objects.
        for idx, sed in enumerate(self.sed_values):
            if isinstance(sed, np.ndarray):
                # Check ndim first so 1-D arrays raise ValueError rather than IndexError.
                if sed.ndim != 2 or sed.shape[0] != 2 or sed.shape[1] < 2:
                    raise ValueError(f"SED {idx} must be a two row numpy array of wavelength and flux.")
                self.sed_values[idx] = SED(sed[0, :], sed[1, :])
            elif not isinstance(sed, SED):
                raise ValueError(f"SED {idx} must be an instance of the SED class or a two row numpy array.")

        super().__init__(**kwargs)

        # Create a parameter that indicates which SED was sampled in each simulation.
        self._sampler_node = GivenValueSampler(list(range(len(self.sed_values))), weights=weights)
        self.add_parameter(
            "selected_idx",
            value=self._sampler_node,
            allow_gradient=False,
            description="Index of the SED selected for sampling.",
        )

    def __len__(self):
        """Get the number of SED values."""
        return len(self.sed_values)

    def __iter__(self):
        """Iterate over the SED values."""
        return iter(self.sed_values)

    def __getitem__(self, index):
        """Get the SED at the given index."""
        return self.sed_values[index]

    @classmethod
    def from_file(cls, sed_file, **kwargs):
        """Load a static SED from a file containing a two column array where the
        first column is wavelength (in angstroms) and the second column is flux (in nJy).

        Parameters
        ----------
        sed_file : str or Path
            The path to the SED file to load.
        **kwargs : dict
            Additional keyword arguments to pass to the StaticSEDModel constructor.

        Returns
        -------
        StaticSEDModel
            An instance of StaticSEDModel with the loaded SED data.
        """
        sed_data = SED.from_file(sed_file)
        return cls(sed_values=sed_data, **kwargs)

    @classmethod
    def from_synphot(cls, sp_model, waves=None, **kwargs):
        """Generate the spectrum from a given synphot model.

        References
        ----------
        synphot (ascl:1811.001)

        Parameters
        ----------
        sp_model : synphot.SourceSpectrum
            The synphot model to generate the spectrum from.
        waves : numpy.ndarray, optional
            A length N array of wavelengths (in angstroms) at which to sample the SED.
            If None, the SED will be sampled at the wavelengths defined in the synphot
            model's waveset.
        **kwargs : dict
            Additional keyword arguments to pass to the StaticSEDModel constructor.

        Returns
        -------
        StaticSEDModel
            An instance of StaticSEDModel with the generated SED data.

        Raises
        ------
        ValueError
            If waves is None and the synphot model does not define a waveset.
        """
        if waves is None:
            waveset = sp_model.waveset
            if waveset is None:
                raise ValueError(
                    "The synphot model does not define a waveset; please provide waves explicitly."
                )
            # Convert to plain angstrom values. Multiplying an angstrom Quantity by
            # u.angstrom would instead yield units of angstrom**2.
            waves = u.Quantity(waveset, u.angstrom).value
        sed_data = SED.from_synphot(sp_model, waves=waves)
        return cls(sed_values=sed_data, **kwargs)

    def minwave(self, graph_state=None):
        """Get the minimum wavelength of the currently selected SED.

        Parameters
        ----------
        graph_state : GraphState, optional
            An object mapping graph parameters to their values. It is used to
            look up which SED was selected ("selected_idx").

        Returns
        -------
        minwave : float or None
            The value returned by the selected SED's minwave() method (in angstroms).
        """
        idx = self.get_param(graph_state, "selected_idx")
        return self.sed_values[idx].minwave()

    def maxwave(self, graph_state=None):
        """Get the maximum wavelength of the currently selected SED.

        Parameters
        ----------
        graph_state : GraphState, optional
            An object mapping graph parameters to their values. It is used to
            look up which SED was selected ("selected_idx").

        Returns
        -------
        maxwave : float or None
            The value returned by the selected SED's maxwave() method (in angstroms).
        """
        idx = self.get_param(graph_state, "selected_idx")
        return self.sed_values[idx].maxwave()

    def compute_sed(self, times, wavelengths, graph_state):
        """Draw effect-free observer frame flux densities.

        Parameters
        ----------
        times : numpy.ndarray
            A length T array of observer frame timestamps in MJD.
        wavelengths : numpy.ndarray
            A length N array of observer frame wavelengths (in angstroms).
        graph_state : GraphState
            An object mapping graph parameters to their values.

        Returns
        -------
        flux_density : numpy.ndarray
            A length T x N matrix of observer frame SED values (in nJy). Every row
            is identical because the SED does not change with time.
        """
        # Use the SED selected by the sampler node to compute the flux density.
        model_ind = self.get_param(graph_state, "selected_idx")
        sed_obj = self.sed_values[model_ind]

        # The SED is static, so interpolate it only once at the query wavelengths...
        sed_fluxes = sed_obj.evaluate(wavelengths)

        # ...and repeat the same interpolated values for every time step.
        num_times = len(times)
        num_waves = len(wavelengths)
        flux_density = np.tile(sed_fluxes, num_times).reshape(num_times, num_waves)
        return flux_density
class StaticBandfluxModel(BandfluxModel):
    """A StaticBandfluxModel randomly selects a mapping of bandfluxes at each evaluation
    and uses that at all time steps.

    Parameterized values include:
      * dec - The object's declination in degrees. [from PhysicalModel]
      * distance - The object's luminosity distance in pc. [from PhysicalModel]
      * ra - The object's right ascension in degrees. [from PhysicalModel]
      * redshift - The object's redshift. [from PhysicalModel]
      * t0 - The t0 of the zero phase, date. [from PhysicalModel. Not used.]
      * selected_idx - The index of the bandflux mapping selected for a given simulation.

    Attributes
    ----------
    bandflux_values : list of dict
        A list of bandflux mappings from which to sample. Each mapping is represented as a
        dictionary where the key is the filter name and the value is the bandflux (in nJy).

    Parameters
    ----------
    bandflux_values : dict or list
        A single bandflux mapping or a list of bandflux mappings from which to sample. Each mapping is
        represented as a dictionary where the key is the filter name and the value is the bandflux (in nJy).
        A given list is copied, so later changes to it do not affect the model.
    weights : numpy.ndarray, optional
        A length N array indicating the relative weight from which to select
        a model at random. If None, all models will be weighted equally.
    **kwargs : dict
        Additional keyword arguments passed to the BandfluxModel constructor.

    Raises
    ------
    ValueError
        If bandflux_values contains no mappings.
    """

    def __init__(
        self,
        bandflux_values,
        weights=None,
        **kwargs,
    ):
        # If only a single bandflux mapping was passed, then put it in a list by itself.
        # Otherwise copy the collection so the model does not alias the caller's container.
        if isinstance(bandflux_values, dict):
            self.bandflux_values = [bandflux_values]
        else:
            self.bandflux_values = list(bandflux_values)

        if len(self.bandflux_values) == 0:
            raise ValueError("bandflux_values must contain at least one bandflux mapping.")

        super().__init__(**kwargs)

        # Create a parameter that indicates which bandflux mapping was sampled in each simulation.
        self._sampler_node = GivenValueSampler(list(range(len(self.bandflux_values))), weights=weights)
        self.add_parameter(
            "selected_idx",
            value=self._sampler_node,
            allow_gradient=False,
            description="Index of the bandflux mapping selected for sampling.",
        )

    def __len__(self):
        """Get the number of band flux values."""
        return len(self.bandflux_values)

    def compute_bandflux(self, times, filter, state, rng_info=None):
        """Evaluate the model at the passband level for a single, given graph state.

        Parameters
        ----------
        times : numpy.ndarray
            A length T array of observer frame timestamps in MJD.
        filter : str
            The name of the filter.
        state : GraphState
            An object mapping graph parameters to their values with num_samples=1.
        rng_info : numpy.random._generator.Generator, optional
            A given numpy random number generator to use for this computation. Not
            used by this model because the output is deterministic given the state.

        Returns
        -------
        bandflux : numpy.ndarray
            A length T array filled with the selected mapping's bandflux (in nJy)
            for the given filter.

        Raises
        ------
        KeyError
            If the selected bandflux mapping has no entry for the given filter.
        """
        # Get the selected bandflux mapping index and values.
        model_ind = self.get_param(state, "selected_idx")
        model_bandflux = self.bandflux_values[model_ind]

        # Look up the filter's bandflux with a descriptive error if it is missing.
        try:
            bandflux = model_bandflux[filter]
        except KeyError:
            raise KeyError(
                f"Filter '{filter}' not found in bandflux mapping {model_ind}. "
                f"Available filters: {list(model_bandflux)}"
            ) from None

        # The bandflux is constant, so repeat it at each time.
        return np.full(len(times), bandflux)