Source code for lightcurvelynx.astro_utils.passbands

"""The Passband and PassbandGroup objects store information about the filters used
to convert flux densities (as a function of wavelength) to bandfluxes. They also provide
methods for loading and manipulating passband data.
"""

import logging
import warnings
from pathlib import Path
from typing import Literal, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.integrate
from astropy import units as u
from astropy.io.votable import parse
from astroquery.svo_fps import SvoFps
from citation_compass import cite_function

from lightcurvelynx import _LIGHTCURVELYNX_DOWNLOAD_DATA_DIR
from lightcurvelynx.consts import lsst_filter_plot_colors
from lightcurvelynx.utils.data_download import download_data_file_if_needed

[docs] logger = logging.getLogger(__name__)
[docs] class PassbandGroup: """A group of passbands. The given passbands can come from a single survey or multiple surveys. As such, a given filter may appear in multiple passbands in a passband group. For example we might generate data from a combination of Rubin and DECCAM and want to use both their r filters. Thus the primary mapping is done by the passband's full name “{SURVEY}_{FILTER}”. Lookups by filter name are permitted in cases where the filter only occurs once in the group and thus the desired passband is unambiguous. Attributes ---------- passbands : dict of Passband A dictionary of Passband objects, where the keys are the full_names of the passbands (eg, "LSST_u"). table_dir : str The path to the directory containing the passband tables. waves : np.ndarray The union of all wavelengths in the passbands. _in_band_wave_indices : dict A dictionary mapping the passband name (eg, "LSST_u") to the indices of that specific passband's wavelengths in the full waves list. _filter_to_name : dict A dictionary mapping the filter name to a list of matching passband full names. """ def __init__(self, given_passbands, filters=None, **kwargs): """Construct a PassbandGroup object. Parameters ---------- given_passbands : Passband, list of Passband, or dict A list of all the passbands to include in this group. These can either be Passband objects or dictionaries of passband parameters with the following keys: - required: - survey : str - The name of the survey to which the passband belongs, e.g., "LSST". - filter_name : str - The name of the filter, e.g., "u". - one of the following: - table_data : np.ndarray - The transmission table data as a (N, 2) array where the first column is wavelengths and the second column is transmission values. - table_path : str or Path - The path of the file from which to load the passband. - table_url : str - The URL from which to download the passband file. - optional: - delta_wave : float - trim_quantile : float - units : str (either 'nm' or 'A') filters : list, optional A list of filters to include in this PassbandGroup. If None, includes all filters. Otherwise drops filters that do not occur and throws an error if a filter is missing. Used for loading a subset of the filters. **kwargs Additional keyword arguments to pass to the Passband constructor. """
[docs] self.passbands = {}
[docs] self._in_band_wave_indices = {}
[docs] self._filter_to_name = {}
# If we are given a single Passband object or a single dictionary, wrap it in a list. if isinstance(given_passbands, Passband): given_passbands = [given_passbands] elif isinstance(given_passbands, dict): given_passbands = [given_passbands] elif given_passbands is None or len(given_passbands) == 0: raise ValueError("No passbands provided to PassbandGroup.") for pb in given_passbands: if isinstance(pb, Passband): # If we are given a Passband object, add it directly. self.passbands[pb.full_name] = pb elif isinstance(pb, dict): # If we are given a dictionary of parameters, create a Passband from it. params = pb.copy() # Add any missing parameters from kwargs. for key, value in kwargs.items(): if key not in params: params[key] = value passband = Passband.from_file(**params) self.passbands[passband.full_name] = passband else: raise TypeError(f"Expected a Passband object or a dictionary of parameters. Got {type(pb)}.") # Prune any filters that are not on the given list and check for any missing filters. # We match on either the full name or the filter name. if filters is not None: filters = set(filters) filters_remaining = filters.copy() all_bands = list(self.passbands.keys()) for pb_name in all_bands: pb_obj = self.passbands[pb_name] if pb_name in filters: filters_remaining.discard(pb_name) elif pb_obj.filter_name in filters: filters_remaining.discard(pb_obj.filter_name) else: del self.passbands[pb_name] if len(filters_remaining) != 0: raise ValueError(f"The following filters were not found: {filters_remaining}") # Compute the unique points and bounds for the group. self._update_internal_data()
[docs] def __str__(self) -> str: """Return a string representation of the PassbandGroup.""" return f"PassbandGroup containing {len(self.passbands)} passbands: {', '.join(self.passbands.keys())}"
[docs] def __len__(self) -> int: return len(self.passbands)
[docs] def __getitem__(self, key): """Return the passband corresponding to a full name or filter name.""" if key in self.passbands: return self.passbands[key] elif key in self._filter_to_name: # If we are looking up the passband by filter name, we check # that the filter only appears in a single Passband object. pb_list = self._filter_to_name[key] if len(pb_list) > 1: raise KeyError( f"Filter {key} corresponds to multiple passbands: {pb_list}.\n" "Lookup the passband by full name." ) return self.passbands[pb_list[0]] else: raise KeyError(f"Unknown passband {key}")
[docs] def __contains__(self, key) -> bool: if key in self.passbands: return True elif key in self._filter_to_name: return True return False
[docs] def __iter__(self): for pb in self.passbands.values(): yield pb
@property
[docs] def filters(self) -> list: """Return a list of filter names in the passband group.""" return list(self._filter_to_name.keys())
@classmethod
[docs] def from_preprocessed_file(cls, file_path: Union[str, Path], **kwargs) -> "PassbandGroup": """Load a PassbandGroup from a single file containing multiple preprocessed passbands. The file should be a CSV file with columns: survey, filter_name, wavelength, transmission. The wavelengths should be in Angstroms. Parameters ---------- file_path : str or Path The path to the file containing the passband data. **kwargs Additional keyword arguments to pass to the Passband constructor. Returns ------- PassbandGroup A PassbandGroup object containing the passbands from the file. """ file_path = Path(file_path) if not file_path.is_file(): raise FileNotFoundError(f"{file_path} is not a valid file.") # Load the table using pandas and check that it has the required columns. df = pd.read_csv(file_path) required_columns = {"survey", "filter_name", "wavelength", "transmission"} if not required_columns.issubset(df.columns): missing_cols = required_columns - set(df.columns) raise ValueError(f"File {file_path} is missing required columns: {missing_cols}") pb_list = [] for (survey, filter_name), group in df.groupby(["survey", "filter_name"]): table_values = group[["wavelength", "transmission"]].to_numpy() pb = Passband( table_values, survey, filter_name, delta_wave=None, # Preserve original grid trim_quantile=None, # Do not trim units="A", # Units are Angstroms ) # Because we are loading from a pre-processed table (normalized system responses instead # of unnormalized system throughputs), we need to retain the original table. pb.normalized_system_response[:, 1] = table_values[:, 1] pb_list.append(pb) # Create a PassbandGroup from the DataFrame. return cls(pb_list, **kwargs)
@classmethod
[docs] def from_dir( cls, dir_path: Union[str, Path], *, filters: list | None = None, delta_wave: float | None = 5.0, trim_quantile: float | None = 1e-3, units: Literal["nm", "A"] | None = "A", ) -> "PassbandGroup": """Load the passbands from a directory where the directorty name corresponds to the survey and the file names correspond to the filters: path_to_survey_dir/survey_name/filter_name.dat Parameters ---------- dir_path : str or Path The path to the passband files including the survey directory. filters : list, optional A list of filters to include in this PassbandGroup. If None, includes all filters. Otherwise drops filters that do not occur and throws an error if a filter is missing. Used for loading a subset of the filters. delta_wave : float or None, optional The grid step of the wave grid, in angstroms. It is typically used to downsample transmission using linear interpolation. Default is 5 angstroms. If None the original grid is used. trim_quantile : float or None, optional The quantile to trim the transmission table by. For example, if trim_quantile is 1e-3, the transmission table will be trimmed to include only the central 99.8% of the area under the transmission curve. units : Literal['nm','A'], optional Denotes whether the wavelength units of the table are nanometers ('nm') or Angstroms ('A'). By default 'A'. Does not affect the output units of the class, only the interpretation of the provided passband table. """ dir_path = Path(dir_path) if not dir_path.is_dir(): raise ValueError(f"{dir_path} is not a valid directory.") # Iterate through the files in the directory. all_params = [] for entry in dir_path.iterdir(): if entry.is_file(): filter_name = entry.stem # If a list of filters is provided, skip the ones that are not in the list. if filters is None or filter_name in filters: params = { "survey": dir_path.name, "filter_name": filter_name, "table_path": dir_path / entry, "delta_wave": delta_wave, "trim_quantile": trim_quantile, "units": units, } all_params.append(params) return PassbandGroup(given_passbands=all_params, filters=filters)
@classmethod
[docs] def from_preset(cls, preset: str, *, table_dir=None, filters=None, **kwargs) -> "PassbandGroup": """Create a passband group from a pre-defined set of passbands. Parameters ---------- preset : str The name of the pre-defined set of passbands to load. table_dir : str, optional The path to the directory in which to store cached passband tables. If the passband exists in this directory, it will be loaded from there; otherwise it will be downloaded and saved in that directory. The full path to the tables will be {table_dir}/{survey}/{filter_name}.dat. filters : list, optional A list of filters to include in this PassbandGroup. If None, includes all filters. Otherwise drops filters that do not occur and throws an error if a filter is missing. Used for loading a subset of the filters. **kwargs Additional keyword arguments to pass to the Passband constructor. """ logger.info(f"Loading passbands from preset {preset}") # If we do not have the base table directory, use the default. table_dir = ( Path(table_dir) if table_dir is not None else Path(_LIGHTCURVELYNX_DOWNLOAD_DATA_DIR, "passbands") ) preset = preset.lower() if preset == "lsst": passbands = PassbandGroup._lsst_load_preset(table_dir, **kwargs) elif preset == "ztf": passbands = [] for filter_name in ["g", "r", "i"]: pb = Passband.from_sncosmo("ZTF", filter_name, f"ztf{filter_name}", **kwargs) passbands.append(pb) elif preset == "roman": passbands = PassbandGroup._roman_load_preset(table_dir, **kwargs) else: raise ValueError(f"Unknown passband preset: {preset}") # Build the actual PassbandGroup object. return cls(given_passbands=passbands, filters=filters, **kwargs)
@staticmethod @cite_function def _lsst_load_preset(table_dir, **kwargs): """Load the LSST passbands from the LSST preset. References ---------- LSST Filters: https://github.com/lsst/throughputs Parameters ---------- table_dir : str or Path The path to the directory in which to store cached passband tables. If the passband exists in this directory, it will be loaded from there; otherwise it will be downloaded and saved in that directory. The full path to the tables will be {table_dir}/LSST/{filter_name}.dat. **kwargs Additional keyword arguments to pass to the Passband constructor. Returns ------- passbands : list of Passband A list of Passband objects for the LSST filters. """ passbands = [] # Check that units are what is expected for this preset. if "units" not in kwargs: kwargs["units"] = "nm" elif kwargs["units"] != "nm": raise ValueError( "LSST passbands are expected to be in nanometers (nm). Please set units='nm' in the kwargs." ) for filter_name in ["u", "g", "r", "i", "z", "y"]: url = f"https://raw.githubusercontent.com/lsst/throughputs/main/baseline/total_{filter_name}.dat" pb = Passband.from_file( "LSST", filter_name, table_path=table_dir / "LSST" / f"{filter_name}.dat", table_url=url, **kwargs, ) passbands.append(pb) return passbands @staticmethod @cite_function def _roman_load_preset(table_dir, **kwargs): """Load the Roman passbands from the Roman preset. References ---------- Roman Filters: https://github.com/RomanSpaceTelescope/roman-technical-information Parameters ---------- table_dir : str or Path The path to the directory in which to store cached passband tables. If the passband exists in this directory, it will be loaded from there; otherwise it will be downloaded and saved in that directory. The full path to the tables will be {table_dir}/Roman/{filter_name}.dat. **kwargs Additional keyword arguments to pass to the Passband constructor. Returns ------- passbands : list of Passband A list of Passband objects for the Roman filters. """ passbands = [] # Check that units are what is expected for this preset. if kwargs.get("units", "micron") != "micron": raise ValueError( f"Roman passbands are expected to be in microns (micron), but got {kwargs.get('units')}" ) force_download = kwargs.get("force_download", False) table_path = table_dir / "Roman" / "roman_wfi_filters.ecsv" table_url = ( "https://raw.githubusercontent.com/RomanSpaceTelescope/roman-technical-information/main/data/" "WideFieldInstrument/Imaging/EffectiveAreas/Roman_effarea_v8_SCA01_20240301.ecsv" ) # Download the table if it does not exist or if force_download is True. success = download_data_file_if_needed(table_path, table_url, force_download=force_download) if not success: raise RuntimeError(f"Failed to download Roman passband table from {table_url}.") # Load the table, convert the wavelengths from microns to Angstroms, create Passband objects. # The table contains effective area in m^2, but since we normalize the transmission tables, # we can just use this directly. table = pd.read_csv(table_path, comment="#", sep=r"[\s,]+", engine="python") waves = table["Wave"].values * 10_000 # Convert microns to Angstroms for filter_name in ["F062", "F087", "F106", "F129", "F146", "F158", "F184", "F213"]: table_values = np.vstack([waves, table[filter_name].values.astype(float)]).T pb = Passband(table_values, "Roman", filter_name, **kwargs) passbands.append(pb) return passbands @classmethod @cite_function
[docs] def from_svo( cls, all_filters: list[str], *, delta_wave: float | None = 5.0, trim_quantile: float | None = 1e-3, **kwargs, ) -> "PassbandGroup": """Create a PassbandGroup object from the SVO Filter Profile Service given a list of full filter names in the form of "{FACILITY}/{INSTRUMENT}.{FILTER}" or a single string of the form "{FACILITY}/{INSTRUMENT}" to download all filters. References ---------- This research has made use of the SVO Filter Profile Service "Carlos Rodrigo", funded by MCIN/AEI/10.13039/501100011033/ through grant PID2023-146210NB-I00 * Rodrigo, C., Cruz, P., Aguilar, J.F., et al. 2024; https://ui.adsabs.harvard.edu/abs/2024A%26A...689A..93R/abstract * Rodrigo, C., Solano, E., Bayo, A., 2012; https://ui.adsabs.harvard.edu/abs/2012ivoa.rept.1015R/abstract * Rodrigo, C., Solano, E., 2020; https://ui.adsabs.harvard.edu/abs/2020sea..confE.182R/abstract Parameters ---------- all_filters : str or list[str] A string of the form "{FACILITY}/{INSTRUMENT}" to download all filters for a given facility and instrument or a list of full filter names to load from the SVO Filter Profile Service in the FORM "{FACILITY}/{INSTRUMENT}.{FILTER}" to download a subset of filters. If a list is provided, it can include filters from multiple surveys. delta_wave : float or None, optional The grid step of the wave grid, in angstroms. It is typically used to downsample transmission using linear interpolation. Default is 5 angstroms. If None the original grid is used. trim_quantile : float or None, optional The quantile to trim the transmission table by. For example, if trim_quantile is 1e-3, the transmission table will be trimmed to include only the central 99.8% of the area under the transmission curve. **kwargs Additional keyword arguments to pass to the Passband constructor. """ if isinstance(all_filters, str): # If we are given a string, we assume it is of the form "{FACILITY}/{INSTRUMENT}" # and we want to download all filters for that facility and instrument. if "/" not in all_filters: raise ValueError( f"Expected a string of the form '{{FACILITY}}/{{INSTRUMENT}}' to download " f"all those filters, but got {all_filters}" ) if "." in all_filters: raise ValueError( f"Expected a string of the form '{{FACILITY}}/{{INSTRUMENT}}' to download " f"all those filters, but got {all_filters}. If you want to specify a single filter, " f"please provide a list of one filter." ) # Do the actual downloading. facility, instrument = all_filters.split("/") all_filter_data = SvoFps.get_filter_list(facility=facility, instrument=instrument) all_filters = all_filter_data["filterID"].tolist() # Download and pre-process each passband. passband_list = [] for full_filter_name in all_filters: pb = Passband.from_svo( full_filter_name, delta_wave=delta_wave, trim_quantile=trim_quantile, **kwargs, ) passband_list.append(pb) return cls(given_passbands=passband_list, **kwargs)
[docs] def add_passband(self, passband) -> None: """Manually add a passband to the group. Parameters ---------- passband : Passband The passband to add to the group. """ self.passbands[passband.full_name] = passband self._update_internal_data()
def _update_internal_data(self) -> None: """Update the cached internal data.""" # Update the mapping of filter name to full_name. self._filter_to_name = {} for full_name, pb_obj in self.passbands.items(): if pb_obj.filter_name in self._filter_to_name: logger.info("Multiple passband objects detected for filter {pb_obj.filter_name}") self._filter_to_name[pb_obj.filter_name].append(full_name) else: self._filter_to_name[pb_obj.filter_name] = [full_name] self._update_waves() def _update_waves(self, threshold=1e-5) -> None: """Update the group's wave attribute to be the union of all wavelengths in the passbands and update the group's _in_band_wave_indices attribute, which is the indices of the group's wave grid that are in the passband's wave grid. Eg, if a group's waves are [11, 12, 13, 14, 15] and a single band's are [13, 14], we get [2, 3]. The indices are stored in the passband's _in_band_wave_indices attribute as either a tuple of two ints (lower, upper) or a 1D np.ndarray of ints. Parameters ---------- threshold : float The threshold for merging two "close" wavelengths. This is used to avoid problems with numerical precision. Default: 1e-5 """ if len(self.passbands) == 0: self.waves = np.array([]) else: # Compute the unique wavelengths (accounting for floating point error) by # sorting the union of all the waves, computing the gaps between each adjacent # pair (using 1e8 for before the first), and saving the points with a gap # larger than the threshold. all_waves = np.concatenate([passband.waves for passband in self.passbands.values()]) sorted_waves = np.sort(all_waves) gap_sizes = np.insert(sorted_waves[1:] - sorted_waves[:-1], 0, 1e8) self.waves = sorted_waves[gap_sizes >= threshold] # Update the mapping of each passband's wavelengths to the corresponding indices in the # unioned list of all wavelengths. self._in_band_wave_indices = {} for name, passband in self.passbands.items(): # We only want the fluxes that are in the passband's wavelength range # So, find the indices in the group's wave grid that are in the passband's wave grid lower, upper = passband.waves[0], passband.waves[-1] lower_index, upper_index = np.searchsorted(self.waves, [lower, upper]) # Check that this is the right grid after all (check will fail if passbands overlap and passbands # do not happen to be on the same phase of the grid; eg, even if the step is 10, if the first # passband starts at 100 and the second at 105, the passbands won't share the same grid) if np.array_equal(self.waves[lower_index : upper_index + 1], passband.waves): indices = slice(lower_index, upper_index + 1) else: indices = np.searchsorted(self.waves, passband.waves) self._in_band_wave_indices[name] = indices
[docs] def wave_bounds(self) -> tuple[float, float]: """Get the minimum and maximum wavelength for this group. Returns ------- min_wave : float The minimum wavelength. max_wave : float The maximum wavelength. """ min_wave = np.min(self.waves) max_wave = np.max(self.waves) return min_wave, max_wave
[docs] def mask_by_filter(self, filters) -> np.ndarray: """Compute a mask for whether a given observations is of interest for for a given analysis. For example this could be used to remove unneeded observations from an ObsTable. Parameters ---------- filters : list-like A length T array of filter names or full names. Returns ------- mask : numpy.ndarray A length T array of Booleans indicating whether the filter is of interest. """ filters = np.asarray(filters) full_name_mask = np.isin(filters, list(self.passbands.keys())) filter_name_mask = np.isin(filters, list(self._filter_to_name.keys())) return full_name_mask | filter_name_mask
[docs] def process_transmission_tables( self, *, delta_wave: float | None = 5.0, trim_quantile: float | None = 1e-3 ) -> None: """Process the transmission tables for all passbands in the group; recalculate group's wave attribute. This function is used to change the preprocessing of the transmission tables after the PassbandGroup has been initialized, such as using different delta_wave or trim_quantile values. Parameters ---------- delta_wave : float or None, optional The grid step of the wave grid. Default is 5.0. trim_quantile : float or None, optional The quantile to trim the transmission table by. For example, if trim_quantile is 1e-3, the transmission table will be trimmed to include only the central 99.8% of rows. """ for passband in self.passbands.values(): passband.process_transmission_table( delta_wave=delta_wave, trim_quantile=trim_quantile, ) self._update_internal_data()
[docs] def fluxes_to_bandflux(self, flux_density_matrix: np.ndarray, filter: str) -> np.ndarray: """Calculate bandfluxes for a single passband in the group. Parameters ---------- flux_density_matrix : np.ndarray A 2D array of flux densities where of shape T x W where the rows are times and columns are wavelengths. filter : str The name of the filter to evaluate. Returns ------- bandflux : np.ndarray A length T array of bandfluxes for the given passband. """ if filter not in self.passbands: if filter in self._filter_to_name: filter = self._filter_to_name[filter][0] else: raise ValueError(f"Filter {filter} not found in passband group.") passband = self.passbands[filter] # Evaluate the bandflux using only the wavelengths for this passband. wave_indices = self._in_band_wave_indices[filter] if wave_indices is None: # pragma: no cover raise ValueError( f"Passband {filter} does not have _in_band_wave_indices set. " "This should have been calculated in PassbandGroup._update_internal_data." ) single_band_fluxes = flux_density_matrix[:, wave_indices] bandflux = passband.fluxes_to_bandflux(single_band_fluxes) return bandflux
[docs] def fluxes_to_bandfluxes(self, flux_density_matrix: np.ndarray) -> np.ndarray: """Calculate bandfluxes for all passbands in the group. Parameters ---------- flux_density_matrix : np.ndarray A 2D array of flux densities where of shape T x W where the rows are times and columns are wavelengths. Returns ------- dict of np.ndarray A dictionary of bandfluxes with passband full names as keys and np.ndarrays of bandfluxes as values. """ if flux_density_matrix.size == 0 or len(self.waves) != len(flux_density_matrix[0]): flux_density_matrix_num_cols = 0 if flux_density_matrix.size == 0 else len(flux_density_matrix[0]) raise ValueError( f"PassbandGroup mismatched grids: Flux density matrix has {flux_density_matrix_num_cols} " f"columns, which does not match transmission table's {len(self.waves)} rows. Check that the " f"flux density matrix was calculated on the same grid as the transmission tables, which can " f"be accessed via the Passband's or PassbandGroup's waves attribute." ) # Compute the bandfluxes for each passband. bandfluxes = {} for full_name in self.passbands: bandfluxes[full_name] = self.fluxes_to_bandflux(flux_density_matrix, full_name) return bandfluxes
[docs] def to_file(self, file_path: Union[str, Path], *, overwrite: bool = False) -> None: """Save the entire passband group to a single file.""" if not overwrite and Path(file_path).exists(): raise FileExistsError(f"File {file_path} already exists. Use overwrite=True to overwrite it.") # Create a single pandas table with the information for each passband. all_data = [] for pb in self.passbands.values(): df = pd.DataFrame(pb.normalized_system_response, columns=["wavelength", "transmission"]) df["survey"] = pb.survey df["filter_name"] = pb.filter_name all_data.append(df) combined_df = pd.concat(all_data, ignore_index=True) combined_df.to_csv(file_path, index=False)
[docs] def plot(self, *, ax=None, figure=None, plot_transmission=False) -> None: """Plot the PassbandGroup on a single plot. Parameters ---------- ax : matplotlib.pyplot.Axes or None, optional Axes, None by default. figure : matplotlib.pyplot.Figure or None Figure, None by default. plot_transmission : bool Whether to plot the original transmission table instead of the normalized system response. Default is False, which plots the normalized system response. """ if ax is None: if figure is None: figure = plt.figure() ax = figure.add_axes([0, 0, 1, 1]) # Plot each passband. for pb_obj in self.passbands.values(): pb_obj.plot(ax=ax, plot_transmission=plot_transmission) ax.legend()
[docs] class Passband: """A passband contains information about its transmission curve and calculates its normalization. Attributes ---------- survey : str The survey to which the passband belongs: eg, "LSST". filter_name : str The name of the passband's filter: eg, "u". full_name : str The full name of the passband. This is the survey and filter concatenated: eg, "LSST_u". waves : np.ndarray The wavelengths of the transmission table in Angstroms. To be used when evaluating models to generate fluxes that will be passed to fluxes_to_bandflux. delta_wave : float or None The grid step of the wave grid, in angstroms, if the table is a uniform grid. The value is None if the grid is not uniform. transmission_table : np.ndarray A 2D array of wavelengths and transmissions. This is the system throughput table loaded from the file, and is neither interpolated nor normalized. normalized_system_response : np.ndarray A 2D array where the first col is wavelengths (Angstrom) and the second col is transmission values. This table is both interpolated to the _wave_grid and normalized to calculate phi_b(λ). """ def __init__( self, table_values: np.array, survey: str, filter_name: str, *, delta_wave: float | None = 5.0, trim_quantile: float | None = 1e-3, units: Literal["nm", "A"] | None = "A", ): """Construct a Passband object. Parameters ---------- table_values : np.ndarray, optional A 2D array of wavelengths (in the given units) and transmissions. survey : str The survey to which the passband belongs: eg, "LSST". filter_name : str The filter_name of the passband: eg, "u". delta_wave : float or None, optional The grid step of the wave grid, in angstroms. It is typically used to downsample transmission using linear interpolation. Default is 5 angstroms. If None the original grid is used. trim_quantile : float or None, optional The quantile to trim the transmission table by. For example, if trim_quantile is 1e-3, the transmission table will be trimmed to include only the central 99.8% of the area under the transmission curve. units : Literal['nm','A'], optional Denotes whether the wavelength units of the table are nanometers ('nm') or Angstroms ('A'). By default 'A'. Does not affect the output units of the class, only the interpretation of the provided passband table. """
[docs] self.survey = survey
[docs] self.filter_name = filter_name
[docs] self.full_name = f"{survey}_{filter_name}"
[docs] self.delta_wave = delta_wave
# Perform validation of the transmission table. if table_values.shape[1] != 2: raise ValueError("Passband requires an input table with exactly two columns.") diffs = np.diff(table_values[:, 0]) if np.any(diffs < 0.0): raise ValueError("Wavelengths in transmission table must be strictly increasing.") if np.any(diffs == 0.0): warnings.warn("Duplicate wavelengths found in transmission table; averaging values.", UserWarning) dup_inds = np.where(diffs == 0.0)[0] table_values[dup_inds, 1] = 0.5 * (table_values[dup_inds, 1] + table_values[dup_inds + 1, 1]) table_values = np.delete(table_values, dup_inds + 1, axis=0)
[docs] self.transmission_table = np.copy(table_values)
# Ensure the wavelengths are in Angstroms. if units == "nm": # Multiply the first column (wavelength) by 10.0 to convert to Angstroms self.transmission_table[:, 0] *= 10.0 elif units == "micron": # Multiply the first column (wavelength) by 10,000 to convert to Angstroms self.transmission_table[:, 0] *= 10_000.0 elif units != "A": raise ValueError(f"Unknown Passband units {units}") # Preprocess the passband. self.process_transmission_table(delta_wave=delta_wave, trim_quantile=trim_quantile)
[docs] def __str__(self) -> str: """Return a string representation of the Passband.""" return f"Passband: {self.full_name}"
[docs] def __eq__(self, other) -> bool: """Determine if two passbands have equal values for the processed tables.""" # Check that they are using the same wavelengths. if len(self.waves) != len(other.waves): return False if not np.allclose(self.waves, other.waves): return False # Check that they have the (approximately) same transmission tables. if self.normalized_system_response.shape != other.normalized_system_response.shape: return False if not np.allclose(self.normalized_system_response, other.normalized_system_response): return False return True
@classmethod
[docs] def from_file( cls, survey: str, filter_name: str, *, delta_wave: float | None = 5.0, trim_quantile: float | None = 1e-3, table_path: Union[str, Path] | None = None, table_url: str | None = None, units: Literal["nm", "A"] | None = "A", force_download: bool = False, ) -> "Passband": """Construct a Passband object from a file, downloading it if needed. Parameters ---------- survey : str The survey to which the passband belongs: eg, "LSST". filter_name : str The filter_name of the passband: eg, "u". delta_wave : float or None, optional The grid step of the wave grid, in angstroms. It is typically used to downsample transmission using linear interpolation. Default is 5 angstroms. If None the original grid is used. trim_quantile : float or None, optional The quantile to trim the transmission table by. For example, if trim_quantile is 1e-3, the transmission table will be trimmed to include only the central 99.8% of the area under the transmission curve. table_path : str, optional The path to the transmission table file. If None, the table path will be set to a default path; if no file exists at this location, the file will be downloaded from table_url. table_url : str, optional The URL to download the transmission table file. If None, the table URL will be set to a default URL based on the survey and filter_name. Default is None. units : Literal['nm','A'], optional Denotes whether the wavelength units of the table are nanometers ('nm') or Angstroms ('A'). By default 'A'. Does not affect the output units of the class, only the interpretation of the provided passband table. force_download : bool, optional If True, the transmission table will be downloaded even if it already exists locally. Default is False. """ if table_path is None: # If no path is given, use the default. table_path = Path( _LIGHTCURVELYNX_DOWNLOAD_DATA_DIR, "passbands", survey, f"{filter_name}.dat", ) else: table_path = Path(table_path) # Download the table if it does not exist or if force_download is True. success = download_data_file_if_needed(table_path, table_url, force_download=force_download) if not success: raise RuntimeError(f"Failed to download passband table from {table_url}.") # Load the table and create the passband. loaded_table = Passband.load_transmission_table(table_path, wave_units=units) return Passband( loaded_table, survey, filter_name, delta_wave=delta_wave, trim_quantile=trim_quantile, units="A", # All loaded tables are pre-converted to Angstroms )
@classmethod @cite_function
[docs] def from_sncosmo(cls, survey: str, filter_name: str, bandpass=None, **kwargs) -> "Passband": """Create a Passband object from an sncosmo.Bandpass object. Parameters ---------- survey : str The survey to which the passband belongs: eg, "LSST". filter_name : str The filter_name of the passband: eg, "u". bandpass : sncosmo.Bandpass or str The bandpass object from which to create the Passband object. **kwargs Additional keyword arguments (unused) References ---------- snocosmo.Bandpass: https://sncosmo.readthedocs.io/en/stable/api/sncosmo.Bandpass.html """ if bandpass is None: raise ValueError("bandpass must be provided as object or string") elif isinstance(bandpass, str): # Only import sncosmo if we need to. try: from sncosmo import get_bandpass except ImportError as err: # pragma: no cover raise ImportError( "sncosmo package is not installed by default. You can install it with " "`pip install sncosmo` or `conda install conda-forge::sncosmo`." ) from err bandpass = get_bandpass(bandpass) table = np.column_stack([bandpass.wave, bandpass.trans]) return Passband( table, survey, filter_name, trim_quantile=None, # Trimming is done in sncosmo units="A", # All sncosmo bandpasses are in Angstroms )
@classmethod @cite_function
[docs] def from_svo( cls, full_filter_name, *, delta_wave: float | None = 5.0, trim_quantile: float | None = 1e-3, **kwargs, ) -> "Passband": """Create a Passband object from the SVO [1]_, [2]_, [3]_ Filter Profile Service. References ---------- .. [1] Rodrigo, C., Cruz, P., Aguilar, J.F., et al. 2024; https://ui.adsabs.harvard.edu/abs/2024A%26A...689A..93R/abstract .. [2] Rodrigo, C., Solano, E., Bayo, A., 2012; https://ui.adsabs.harvard.edu/abs/2012ivoa.rept.1015R/abstract .. [3] Rodrigo, C., Solano, E., 2020; https://ui.adsabs.harvard.edu/abs/2020sea..confE.182R/abstract Parameters ---------- full_filter_name : str The full name of the survey and filter in the SVO database in the form of "{FACILITY}/{INSTRUMENT}.{FILTER}", e.g., "SLOAN/SDSS.u". delta_wave : float or None, optional The grid step of the wave grid, in angstroms. It is typically used to downsample transmission using linear interpolation. Default is 5 angstroms. If None the original grid is used. trim_quantile : float or None, optional The quantile to trim the transmission table by. For example, if trim_quantile is 1e-3, the transmission table will be trimmed to include only the central 99.8% of the area under the transmission curve. force_download : bool, optional If True, the transmission table will be downloaded even if it already exists locally. Default is False. **kwargs Additional keyword arguments to pass to the Passband constructor. """ # Parse the filter name to get the survey and filter and use it to construct the file # path and URL to the SVO database. survey, filter = full_filter_name.split(".") # Use astroquery to download the filter. data = SvoFps.get_transmission_data(full_filter_name) # If the wavelength is given in anything other than Angstroms, convert it to Angstroms. if data["Wavelength"].unit is not None and data["Wavelength"].unit != u.AA: data["Wavelength"] = data["Wavelength"].to(u.AA) # The data is returned as a table with columns "Wavelength" and "Transmission", each # of which can be masked. We convert this to a 2D numpy array, dropping the masked values. waves = data["Wavelength"].filled(np.nan).data trans = data["Transmission"].filled(np.nan).data valid = ~np.isnan(waves) & ~np.isnan(trans) table_data = np.vstack([waves[valid], trans[valid]]).T return cls( table_data, survey=survey, filter_name=filter, delta_wave=delta_wave, trim_quantile=trim_quantile, **kwargs, )
@staticmethod
[docs] def load_transmission_table( table_path: Union[str, Path], *, wave_units: Literal["nm", "A"] = "A", **kwargs, ) -> np.ndarray: """Load a transmission table from a file in either VOTable of ASCII format. ASCII tables must have 2 columns: wavelengths and transmissions; wavelengths must be strictly increasing. Parameters ---------- table_path : str or Path The path to the transmission table file. wave_units : Literal['nm','A'], optional Denotes whether the wavelength units of the table are nanometers ('nm') or Angstroms ('A'). By default 'A'. Does not affect the output units of the class, only the interpretation of the provided passband table. **kwargs : dict Additional keyword arguments to pass to the reader method. Returns ------- np.ndarray A 2D array of wavelengths (in Angstroms) and transmissions. """ logger.info(f"Loading passband from file: {table_path}") table_path = Path(table_path) if not table_path.exists(): raise FileNotFoundError(f"Transmission table not found at {table_path}") # Check if the file is in a VOTable format. if table_path.suffix in [".xml", ".vot", ".votable"]: table = parse(table_path).get_first_table() # Check that we have the correct data. if len(table.fields) != 2: raise ValueError("VOTable must have exactly 2 columns.") if not table.fields[0].name.lower().startswith("wave"): raise ValueError("VOTable first column must be named 'wavelength'.") # Transform the table from rows of (wavelength, transmission) tuples to a 2D array. loaded_table = np.zeros((len(table.array), 2), dtype=float) loaded_table[:, 0] = table.array[table.fields[0].name] loaded_table[:, 1] = table.array[table.fields[1].name] # Read the wavelength units from the VOTable. wave_unit_name = table.fields[0].unit.name.lower() if wave_unit_name in ["nm", "nanometers", "namo", "nanometer"]: wave_units = "nm" elif wave_unit_name in ["a", "aa", "angstrom", "angstroms"]: wave_units = "A" else: raise ValueError(f"Unsupported wavelength unit in VOTable: {wave_unit_name}") elif table_path.suffix in [".csv", ".ecsv", ".txt", ".dat"]: # Add default delimiter if not provided if (table_path.suffix == ".csv" or table_path.suffix == ".ecsv") and "delimiter" not in kwargs: kwargs["delimiter"] = "," # Load the table. loaded_table = np.loadtxt(table_path, **kwargs) elif table_path.suffix in [".parquet"]: loaded_table = pd.read_parquet(table_path, **kwargs).to_numpy() else: raise ValueError(f"Unsupported file format for transmission table: {table_path.suffix}") # Check that the table has the correct shape if loaded_table.size == 0 or loaded_table.shape[1] != 2: raise ValueError("Transmission table must have exactly 2 columns.") # If the table is given in nanometers, convert to Angstroms (by multiplying by 10.0). if wave_units == "nm": loaded_table[:, 0] *= 10.0 elif wave_units != "A": raise ValueError(f"Unknown wavelength units {wave_units}. Expected 'nm' or 'A'.") # Check that wavelengths are strictly increasing. If there are duplicates then # we average the values. diffs = np.diff(loaded_table[:, 0]) if np.any(diffs < 0.0): raise ValueError("Wavelengths in transmission table must be increasing.") if np.any(diffs == 0.0): warnings.warn("Duplicate wavelengths found in transmission table; averaging values.", UserWarning) dup_inds = np.where(diffs == 0.0)[0] loaded_table[dup_inds, 1] = 0.5 * (loaded_table[dup_inds, 1] + loaded_table[dup_inds + 1, 1]) loaded_table = np.delete(loaded_table, dup_inds + 1, axis=0) return loaded_table
[docs] def process_transmission_table( self, *, delta_wave: float | None = 5.0, trim_quantile: float | None = 1e-3, ) -> None: """Process the transmission table, transforming it to the desired wave grid and and computing a normalized system response from the throughput table. Parameters ---------- delta_wave : Optional[float] = 5.0 The grid step in Angstroms of the wave grid. Default is 5.0 Angstroms. trim_quantile : Optional[float] = 1e-3 The quantile to trim the transmission table by. For example, if trim_quantile is 1e-3, the transmission table will be trimmed to include only the central 99.8% of rows. """ interpolated_table = self.interpolate_transmission_table(self.transmission_table, delta_wave) trimmed_table = self.trim_transmission_by_quantile(interpolated_table, trim_quantile) self.normalized_system_response = self.compute_system_response_table(trimmed_table) self.waves = self.normalized_system_response[:, 0] self.delta_wave = delta_wave
@staticmethod
[docs] def interpolate_transmission_table(table: np.ndarray, delta_wave: float | None) -> np.ndarray: """Interpolate the transmission table to a new wave grid. Parameters ---------- table : np.ndarray A 2D array of wavelengths (in Angstroms) and transmissions. delta_wave : float or None The grid step in Angstroms of the wave grid. Returns ------- np.ndarray The 2D interpolated array of wavelengths (in Angstroms) and transmissions. """ # Don't interpolate if delta_wave is None or the table is already on the desired grid if delta_wave is None: return table if np.allclose(np.diff(table[:, 0]), delta_wave): return table # Regrid wavelengths to the new wave grid wavelengths = table[:, 0] lower_bound, upper_bound = wavelengths[0], wavelengths[-1] new_wavelengths = np.linspace( lower_bound, upper_bound, int((upper_bound - lower_bound) / delta_wave) + 1 ) # Interpolate the transmission table to the new wave grid spline = scipy.interpolate.InterpolatedUnivariateSpline(table[:, 0], table[:, 1], ext="raise", k=1) interpolated_transmissions = spline(new_wavelengths) return np.column_stack((new_wavelengths, interpolated_transmissions))
@staticmethod
[docs] def trim_transmission_by_quantile(table: np.ndarray, trim_quantile: float | None) -> np.ndarray: """Trim the transmission table so that it only includes the central (100 - 2*trim_quartile)% of rows. E.g., if trim_quantile is 1e-3, the transmission table will be trimmed to include only the central 99.8% of rows. Parameters ---------- table : np.ndarray A 2D array of wavelengths (in Angstroms) and transmissions. trim_quantile : float The quantile to trim the transmission table by. For example, if trim_quantile is 1e-3, the transmission table will be trimmed to include only the central 99.8% of rows. Must be greater than or equal to 0 and less than 0.5. Returns ------- np.ndarray A 2D array of wavelengths (in Angstroms) and transmissions. """ if trim_quantile is None or trim_quantile == 0.0: return table if trim_quantile < 0 or trim_quantile >= 0.5: raise ValueError(f"Trim quantile must be between 0 and 0.5; got {trim_quantile}.") # Separate wavelengths and transmissions wavelengths = table[:, 0] transmissions = table[:, 1] # Calculate the cumulative sum of the transmission values (area under the curve) cumulative_area = scipy.integrate.cumulative_trapezoid(transmissions, x=wavelengths) # Normalize cumulative area to range from 0 to 1 cumulative_area /= cumulative_area[-1] # Find indices where the cumulative area exceeds the trim quantiles lower_bound = max(np.searchsorted(cumulative_area, trim_quantile, side="right") - 1, 0) upper_bound = min(np.searchsorted(cumulative_area, 1 - trim_quantile), len(table) - 1) # Trim the table to the desired range trimmed_table = table[lower_bound : upper_bound + 1] return trimmed_table
@staticmethod
[docs] def compute_system_response_table(transmission_table: np.ndarray) -> np.ndarray: """Calculate the value of phi_b for all wavelengths in a transmission table. This is eq. 8 from "On the Choice of LSST Flux Units" (Ivezić et al.): φ_b(λ) = S_b(λ)λ⁻¹ / ∫ S_b(λ)λ⁻¹ dλ where S_b(λ) is the system response of the passband. Note we use transmission table as our S_b(λ). Parameters ---------- transmission_table : np.ndarray A 2D array of wavelengths (in Angstroms) and throughput values. Returns ------- np.ndarray A 2D array of wavelengths (in Angstroms) and normalized system response values. Raises ------ ValueError If the transmission table is the wrong size or the calculated denominator is zero. """ if transmission_table.size == 0: raise ValueError("Transmission table is empty; cannot normalize.") elif transmission_table.shape[0] == 1: raise ValueError(f"Cannot normalize transmission table with only one row: {transmission_table}.") elif transmission_table.ndim != 2 or transmission_table.shape[1] != 2: raise ValueError("Transmission table must be 2D array with exactly 2 columns.") wavelengths_angstrom = transmission_table[:, 0] transmissions = transmission_table[:, 1] # Calculate the numerators and denominator numerators = transmissions / wavelengths_angstrom denominator = scipy.integrate.trapezoid(numerators, x=wavelengths_angstrom) if denominator == 0: raise ValueError("Denominator is zero; cannot normalize transmission table.") # Calculate phi_b for each wavelength normalized_system_response = numerators / denominator return np.column_stack((wavelengths_angstrom, normalized_system_response))
[docs] def fluxes_to_bandflux( self, flux_density_matrix: np.ndarray, ) -> np.ndarray: """Calculate the bandflux for a given set of flux densities. This is eq. 7 from "On the Choice of LSST Flux Units" (Ivezić et al.): F_b = ∫ f(λ)φ_b(λ) dλ where f(λ) is the flux density of an object at the top of the atmosphere, and φ_b(λ) is the normalized system response for given band b." Parameters ---------- flux_density_matrix : np.ndarray A 2D or 3D array of flux densities. If the array is 2D it contains a single sample where the rows are the T times and columns are M wavelengths in Angstroms. If the array is 3D it contains S samples and the values are indexed as (sample_num, time, wavelength). Returns ------- bandfluxes : np.ndarray A 1D or 2D array. If the flux_density_matrix contains a single sample (2D input) then the function returns a 1D length T array where each element is the bandflux at the corresponding time. Otherwise the function returns a size S x T array where each entry corresponds to the value for a given sample at a given time. """ if flux_density_matrix.size == 0: raise ValueError("Empty flux density matrix used.") if len(flux_density_matrix.shape) == 2: w_axis = 1 flux_density_matrix_num_cols = flux_density_matrix.shape[1] elif len(flux_density_matrix.shape) == 3: w_axis = 2 flux_density_matrix_num_cols = flux_density_matrix.shape[2] else: raise ValueError("Invalid flux density matrix. Must be 2 or 3-dimensional.") # Check the number of wavelengths match. if len(self.waves) != flux_density_matrix_num_cols: raise ValueError( f"Passband mismatched grids: Flux density matrix has {flux_density_matrix_num_cols} " f"columns, which does not match the {len(self.waves)} rows in band {self.full_name}'s " f"transmission table. Check that the flux density matrix was calculated on the same grid as " f"the transmission tables, which can be accessed via the Passband's or PassbandGroup's waves " f"attribute." ) # Calculate the bandflux as ∫ f(λ)φ_b(λ) dλ, # where f(λ) is the flux density and φ_b(λ) is the normalized system response integrand = flux_density_matrix * self.normalized_system_response[:, 1] if self.delta_wave is not None: # If the grid is equal spaced, we can use a faster method of computing a rectangular # integration and removing half the first and last values (to make it trapezoidal). first_val = np.take(integrand, 0, axis=w_axis) last_val = np.take(integrand, -1, axis=w_axis) bandfluxes = (np.sum(integrand, axis=w_axis) - 0.5 * (first_val + last_val)) * self.delta_wave else: # Do the full integration. bandfluxes = scipy.integrate.trapezoid(integrand, x=self.waves, axis=w_axis) return bandfluxes
[docs] def plot( self, *, ax=None, figure=None, color=None, plot_transmission=False, ) -> None: """Plot the passband. Parameters ---------- ax : matplotlib.pyplot.Axes or None, optional Axes, None by default. figure : matplotlib.pyplot.Figure or None Figure, None by default. color : str or None, optional The color of the curve. plot_transmission : bool Plot the loaded transmission table instead of the normalized system response. Default is False (which plots the normalized system response). """ if ax is None: if figure is None: figure = plt.figure() ax = figure.add_axes([0, 0, 1, 1]) # If the color is provided, we use that. Otherwise we try # the LSST filter colors (or default to black). if color is None: color = lsst_filter_plot_colors.get(self.filter_name, "black") if plot_transmission: ax.plot( self.transmission_table[:, 0], # X values are the wavelength self.transmission_table[:, 1], # Y values are the transmission values. color=color, label=self.full_name, ) ax.set_ylabel("Transmission Value") else: ax.plot( self.normalized_system_response[:, 0], # X values are the wavelength self.normalized_system_response[:, 1], # Y values are the transmission values. color=color, label=self.full_name, ) ax.set_ylabel(r"Normalized Response, $1/\AA$") ax.set_xlabel(r"Wavelength, $\AA$") ax.set_ylim(0, None)