Source code for lightcurvelynx.astro_utils.sed_basis_models

"""Functions for creating *fake* SED basis functions for each filter.

These functions are primarily used to test and validate bandflux-only models
by allowing them to compute some SED (that will be accurate when passed through
the filter curves). Users should NOT use these functions with transforms such
as redshift.
"""

import matplotlib.pyplot as plt
import numpy as np

from lightcurvelynx.astro_utils.passbands import Passband, PassbandGroup
from lightcurvelynx.consts import lsst_filter_plot_colors


class SEDBasisModel:
    """A simple class to hold SED basis model information.

    Each filter is associated with one basis SED sampled on a shared
    wavelength grid. A bandflux in a filter can be turned into an SED by
    scaling that filter's basis function.

    Attributes
    ----------
    wavelengths : np.ndarray
        A 1-dimensional array of wavelengths (in Angstroms).
    sed_values : dict
        A dictionary mapping filter names to their corresponding SED values
        at each entry of ``wavelengths``.
    """

    def __init__(self, wavelengths, sed_values):
        """Create the model from a wavelength grid and per-filter SED values.

        Parameters
        ----------
        wavelengths : np.ndarray
            A 1-dimensional array of wavelengths (in Angstroms).
        sed_values : dict
            A dictionary mapping filter names to arrays of SED values, each
            with the same length as ``wavelengths``.

        Raises
        ------
        ValueError
            If any SED array has a different length than ``wavelengths``.
        """
        # Validate before copying so no work is done for malformed input.
        for filter_name, sed in sed_values.items():
            if len(sed) != len(wavelengths):
                raise ValueError(
                    f"Length of SED values for filter {filter_name} ({len(sed)}) does not match "
                    f"the length of wavelengths ({len(wavelengths)})."
                )

        # Store copies so later changes to the caller's arrays do not alter the model.
        self.wavelengths = wavelengths.copy()
        self.sed_values = {k: v.copy() for k, v in sed_values.items()}

    def __len__(self):
        """Return the number of filters in the SED basis model."""
        return len(self.sed_values)

    @property
    def filters(self):
        """Return a list of all the filter names in the SED basis model."""
        return list(self.sed_values.keys())

    @classmethod
    def from_box_approximation(cls, passbands, filters=None):
        """Create box-shaped SED basis functions.

        For each passband this creates a box shaped SED that does not overlap
        with any other passband. The height of the SED is normalized such that
        the total flux density will be 1.0 after passing through the passband.

        Parameters
        ----------
        passbands : PassbandGroup or Passband
            The passband group to use for defining the light curve. A single
            Passband is wrapped in a PassbandGroup.
        filters : list, optional
            A list of filters to use for the model. If not provided,
            use all filters in the passband group.

        Returns
        -------
        sed_basis_values : SEDBasisModel
            The basis model for this set of filters.

        Raises
        ------
        ValueError
            If a requested filter is not in the passband group, if a filter has
            no significant wavelengths that are free of overlap with the other
            requested filters, or if the computed total flux is not positive.
        """
        if isinstance(passbands, Passband):
            passbands = PassbandGroup(given_passbands=[passbands])
        if filters is None:
            filters = passbands.filters

        # Transmission values at or below this threshold are treated as negligible.
        # The same value is reported in the error message below so the two cannot drift.
        min_transmission = 1e-5

        # Mark which wavelengths are used by each passband.
        waves_per_filter = np.zeros((len(filters), len(passbands.waves)))
        for idx, filter_name in enumerate(filters):
            if filter_name not in passbands.filters:
                raise ValueError(f"Filter {filter_name} not found in passband group.")

            # Get all of the wavelengths that have a non-negligible transmission value
            # for this filter and find their indices in the passband group.
            passband = passbands[filter_name]
            is_significant = passband.normalized_system_response[:, 1] > min_transmission
            significant_waves = passband.waves[is_significant]
            indices = np.searchsorted(passbands.waves, significant_waves)

            # Mark all non-negligible wavelengths as used by this filter.
            waves_per_filter[idx, indices] = 1.0

        # Count how many filters use each wavelength (to detect overlaps).
        filter_counts = np.sum(waves_per_filter, axis=0)

        # Create the sed values for each wavelength.
        sed_basis_values = {}
        for idx, filter_name in enumerate(filters):
            # Get the wavelengths that are used by ONLY this filter.
            valid_waves = (waves_per_filter[idx, :] == 1) & (filter_counts == 1)
            if not np.any(valid_waves):
                raise ValueError(
                    f"Passband {filter_name} has no valid wavelengths where it: a) has a non-negligible "
                    f"transmission value (>{min_transmission}) and b) does not overlap with another passband."
                )

            # Compute how much flux is passed through these wavelengths of this filter
            # and use this to normalize the sed values.
            filter_sed_basis = np.zeros((1, len(passbands.waves)))
            filter_sed_basis[0, valid_waves] = 1.0
            total_flux = passbands.fluxes_to_bandflux(filter_sed_basis, filter_name)
            if total_flux[0] <= 0:
                raise ValueError(f"Total flux for filter {filter_name} is {total_flux[0]}.")
            sed_basis_values[filter_name] = filter_sed_basis[0, :] / total_flux[0]

        sed_basis = cls(passbands.waves, sed_basis_values)
        return sed_basis

    def compute_sed(self, filter, wavelengths=None):
        """Compute the SED values for a given filter at a single time.

        Parameters
        ----------
        filter : str
            The filter for which to compute the SED.
        wavelengths : np.ndarray, optional
            A 1-dimensional array of wavelengths (in Angstroms) at which to compute the SED.
            If None, use the internal wavelengths.

        Returns
        -------
        np.ndarray
            A 1-dimensional array of SED values at each wavelength.

        Raises
        ------
        KeyError
            If ``filter`` is not in the SED basis model.
        """
        if wavelengths is None:
            wavelengths = self.wavelengths

        sed_waves = np.interp(
            wavelengths,  # The query wavelengths
            self.wavelengths,  # The model's own wavelength grid
            self.sed_values[filter],  # The SED values on the model's wavelength grid
            left=0.0,  # Do not extrapolate in wavelength
            right=0.0,  # Do not extrapolate in wavelength
        )
        return sed_waves

    def compute_sed_from_bandfluxes(self, filters, bandfluxes, wavelengths=None):
        """Compute the SED values from given bandfluxes.

        Parameters
        ----------
        filters : list of str
            A length T array of filter names for each time.
        bandfluxes : np.ndarray
            A length T array of bandflux values for each filter (in nJy).
        wavelengths : np.ndarray, optional
            A length W array of wavelengths (in Angstroms) at which to compute the SED.
            If None, use the internal wavelengths.

        Returns
        -------
        flux_density : numpy.ndarray
            A length T x W matrix of observer frame SED values (in nJy).

        Raises
        ------
        ValueError
            If ``filters`` and ``bandfluxes`` have different shapes or if a
            filter is not in the SED basis model.
        """
        if wavelengths is None:
            wavelengths = self.wavelengths

        filters = np.asarray(filters)
        bandfluxes = np.asarray(bandfluxes)
        if filters.shape != bandfluxes.shape:
            raise ValueError("Length of filters and bandfluxes must match.")

        # Fill in the rows of each filter with its basis SED scaled by the bandflux.
        flux_density = np.zeros((len(filters), len(wavelengths)), dtype=float)
        for filter_name in np.unique(filters):
            if filter_name not in self.sed_values:
                raise ValueError(f"Filter {filter_name} not found in SED basis model.")

            filter_mask = filters == filter_name
            sed_values = self.compute_sed(filter_name, wavelengths)
            flux_density[filter_mask, :] = np.outer(bandfluxes[filter_mask], sed_values)

        return flux_density

    def plot(self, ax=None, figure=None):
        """Plot the basis functions for the SED.

        This is a debugging function to help the user understand the SEDs
        produced by this model.

        Parameters
        ----------
        ax : matplotlib.pyplot.Axes or None, optional
            Axes, None by default.
        figure : matplotlib.pyplot.Figure or None
            Figure, None by default. Only used when ``ax`` is None.
        """
        if ax is None:
            if figure is None:
                figure = plt.figure()
            ax = figure.add_axes([0, 0, 1, 1])

        # Plot the basis function of each filter (unknown filter names are drawn in black).
        for filter_name, filter_curve in self.sed_values.items():
            color = lsst_filter_plot_colors.get(filter_name, "black")
            ax.plot(self.wavelengths, filter_curve, color=color, label=filter_name)

        # Set the x and y axis labels.
        ax.set_xlabel("Wavelength (Angstroms)")
        ax.set_ylabel("SED (nJy)")
        ax.set_title("SED Basis Functions")
        ax.legend()