# Source code for lightcurvelynx.obstable.obs_table

"""The top-level module for survey-related data, such as pointing and observation conditions.

The ObsTable class is a base class; specific implementations exist for different survey
data, such as Rubin and ZTF.
"""

import logging
import sqlite3
import warnings
from pathlib import Path

import astropy.units as u
import numpy as np
import pandas as pd
from astropy.coordinates import Latitude, Longitude
from mocpy import MOC
from regions import Region
from scipy.spatial import KDTree
from tqdm import tqdm

from lightcurvelynx.astro_utils.coordinate_utils import dedup_coords, ra_dec_to_cartesian
from lightcurvelynx.astro_utils.detector_footprint import DetectorFootprint
from lightcurvelynx.astro_utils.mag_flux import mag2flux
from lightcurvelynx.utils.io_utils import read_sqlite_table
from lightcurvelynx.utils.plotting import plot_moc

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class ObsTable:
    """A class that stores a table of information about the observations in a survey,
    such as pointing and observation conditions.

    ObsTables are specialized for different surveys and include information about that
    survey (e.g. default noise parameters, default filter characteristics, etc.). They also
    include helper functions for common computations.

    Parameters
    ----------
    table : dict or pandas.core.frame.DataFrame
        The table with all the survey information. Metadata can be included in the
        "lightcurvelynx_survey_data" entry of the attributes dictionary.
    colmap : dict, optional
        A mapping of standard column names to a list of possible names in the input table.
        Each value in the dictionary can be a string or a list of strings. For example,
        in Rubin's OpSim we might have the column "observationStartMJD" which maps to "time".
        In that case we would have an entry with key="time" and value="observationStartMJD".
    detector_footprint : astropy.regions.SkyRegion, Astropy.regions.PixelRegion, or
        DetectorFootprint, optional
        The footprint object for the instrument's detector. If None, no footprint
        filtering is done. Default is None.
    wcs : astropy.wcs.WCS, optional
        The WCS for the footprint. Either this or pixel_scale must be provided if
        a footprint is provided as a Astropy region.
    saturation_mags : dict, optional
        A dictionary mapping filter names to their saturation thresholds in magnitudes.
        The filters provided must match those in the table. If not provided,
        saturation effects will not be applied.
    **kwargs : dict
        Additional keyword arguments to pass to the constructor. This can include
        overrides of any of the survey values.

    Attributes
    ----------
    survey_values : dict, optional
        A mapping for constant values for the survey used in various computations,
        such as readout noise and dark current.
    filters : np.ndarray
        The unique filters in the survey table (if provided).
    _table : pandas.core.frame.DataFrame
        The table with all the observation information mapped to standard column names.
    _colmap : dict
        A mapping of standard column names to their names in the input table.
    _inv_colmap : dict
        A dictionary mapping the custom column names back to the standard names.
    _spatial_data : scipy.spatial.KDTree or None
        A spatial_data structure of the survey pointings for fast spatial queries. We use
        the scipy kd-tree for most of the implementations instead of astropy's functions
        so we can directly control caching.
    _detector_footprint : DetectorFootprint, optional
        The footprint object for the instrument's detector. If None, no footprint
        filtering is done. Default is None.
    _saturation_mags : dict, optional
        The saturation thresholds in magnitudes for each filter. If unspecified, an
        instrument-specific default will be used, if available.
    """

    _required_columns = ["ra", "dec", "time"]

    # Default survey values. Most of these are all None for the abstract base class.
    _default_survey_values = {
        "dark_current": None,
        "ext_coeff": None,
        "pixel_scale": None,
        "radius": None,
        "read_noise": None,
        "zp_per_sec": None,
        "zp_err_mag": 0.0,  # Default of no noise floor.
        "survey_name": "Unknown",
    }

    # An alternate mapping of filter names to handle changes in schema.
    # The keys are the alternate names and the values are the standard names.
    _alt_filter_name_map = {}

    def __init__(
        self,
        table,
        *,
        colmap=None,
        detector_footprint=None,
        wcs=None,
        saturation_mags=None,
        **kwargs,
    ):
        # Create a copy of the table.
        if isinstance(table, dict):
            self._table = pd.DataFrame(table)
        else:
            self._table = table.copy()

        # Remap the columns to standard names. Start with the existing names (from the table)
        # and overwrite anything provided by the column map. The column map can have multiple
        # potential options for each standard name to handle changes in schema
        # (e.g. Rubin Opsim, DP1, DP2, etc.).
        #   name_map maps each current column name -> standard name.
        #   _inv_colmap records, for each remapped column, original name -> standard name,
        #   so that values can later be looked up by either name.
        name_map = {col: col for col in self._table.columns}
        self._inv_colmap = {}
        self._colmap = colmap if colmap is not None else {}

        all_cols = set(self._table.columns)
        if colmap is not None:
            for key, value in colmap.items():
                # If the standard column name (key) could be taken from multiple options,
                # use the first one that exists in the table (or the first option if
                # none of them are found).
                if isinstance(value, list):
                    value = next((val for val in value if val in all_cols), value[0])

                if value in name_map:
                    # Check for collisions (mapping a column to an existing column)
                    if key in self._table.columns and key != value:
                        raise ValueError(f"Trying to map {value} to {key}, but {key} is already a column.")

                    # Add this entry to the list of column names that need to be remapped.
                    name_map[value] = key

                    # Save the inverse mapping as well
                    self._inv_colmap[value] = key
        self._table.rename(columns=name_map, inplace=True)

        # Check that we have the required columns and filter out any NaNs.
        for col in self._required_columns:
            if col not in self._table.columns:
                raise KeyError(f"Missing required column: {col}")
            if self._table[col].isna().any():
                warnings.warn(
                    f"Found NaN values in required column '{col}'. "
                    "Dropping rows with NaN values in this column.",
                    stacklevel=2,
                )
                self._table = self._table.dropna(subset=[col]).reset_index(drop=True)
        logger.debug(f"ObsTable initialized with columns: {self._table.columns.tolist()}")

        # If we have a filter column, check whether we should remap any of the filter values.
        if "filter" in self._table.columns:
            for alt_name, standard_name in self._alt_filter_name_map.items():
                mask = self._table["filter"] == alt_name
                if mask.any():
                    logger.debug(f"Remapping filter name '{alt_name}' to standard name '{standard_name}'.")
                    self._table.loc[mask, "filter"] = standard_name

        # Save the survey values, with table metadata and keyword arguments overwriting the defaults.
        self.survey_values = self._default_survey_values.copy()
        if "lightcurvelynx_survey_data" in self._table.attrs:
            metadata = self._table.attrs["lightcurvelynx_survey_data"]
            if not isinstance(metadata, dict):
                raise TypeError("Got unexpected type for lightcurvelynx_survey_data")
            for key, value in metadata.items():
                self.survey_values[key] = value
        for key, value in kwargs.items():
            self.survey_values[key] = value
        logger.debug(f"ObsTable survey values: {self.survey_values}")

        # The filters are needed before _derive_noise_columns() runs; they are recomputed
        # in _update_cached_data() below.
        self.filters = np.unique(self._table["filter"]) if "filter" in self._table.columns else np.array([])

        # Derive any additional noise columns the survey might need.
        self._derive_noise_columns()

        # Save the saturation thresholds.
        self._saturation_mags = saturation_mags

        # Update all of the cached data.
        self._update_cached_data()

        # Create the footprint if one is provided.
        if detector_footprint is not None:
            self.set_detector_footprint(detector_footprint, wcs=wcs)
        else:
            self._detector_footprint = None

    def __len__(self):
        """Return the number of rows in the observation table."""
        return len(self._table)

    def __getitem__(self, key):
        """Access the underlying observation table by column or parameter name.

        This will return either a full column from the table or a survey parameter value.
        Table columns take priority, then columns addressed by their original
        (pre-remapping) name, then the survey values.

        Raises
        ------
        KeyError
            If the key matches neither a column nor a survey value.
        """
        if key in self._table.columns:
            return self._table[key]
        if key in self._inv_colmap and self._inv_colmap[key] in self._table.columns:
            return self._table[self._inv_colmap[key]]
        if key in self.survey_values:
            return self.survey_values[key]
        raise KeyError(f"Column or parameter not found: {key}")

    def __contains__(self, key):
        """Check if a column exists in the survey table or a parameter in the parameter table.

        A survey value only counts as present if it is not None.
        """
        if key in self._table.columns:
            return True
        if key in self._inv_colmap and self._inv_colmap[key] in self._table.columns:
            return True
        if key in self.survey_values and self.survey_values[key] is not None:
            return True
        return False

    def copy(self):
        """Create a copy of the ObsTable.

        The copy is built with the same class as this object (matching
        make_resampled_table), so copying a subclass instance yields that subclass.
        The detector footprint object is shared, not copied.
        """
        new_saturation_mags = self._saturation_mags.copy() if self._saturation_mags is not None else None
        return self.__class__(
            self._table.copy(),
            colmap=self._colmap.copy(),
            detector_footprint=self._detector_footprint,
            saturation_mags=new_saturation_mags,
            **self.survey_values.copy(),
        )

    def head(self, n=5):
        """Return the first n rows of the observation table."""
        return self._table.head(n)

    def uses_footprint(self):
        """Return whether the ObsTable uses a detector footprint for filtering."""
        return self._detector_footprint is not None

    def clear_detector_footprint(self):
        """Clear the detector footprint, so no footprint filtering is done."""
        self._detector_footprint = None

    def set_detector_footprint(self, detector_footprint, wcs=None):
        """Set the detector footprint, so footprint filtering is done.

        Parameters
        ----------
        detector_footprint : astropy.regions.SkyRegion, Astropy.regions.PixelRegion, or
            DetectorFootprint
            The footprint object for the instrument's detector.
        wcs : astropy.wcs.WCS, optional
            The WCS for the footprint. Either this or pixel_scale must be provided if
            a footprint is provided as a Astropy region.
        """
        # Wrap raw regions in a DetectorFootprint.
        if isinstance(detector_footprint, Region):
            pixel_scale = self.survey_values.get("pixel_scale", None)
            detector_footprint = DetectorFootprint(detector_footprint, wcs=wcs, pixel_scale=pixel_scale)
        self._detector_footprint = detector_footprint

        # Check that the radius is valid for the given footprint (if it exists).
        if self._detector_footprint is not None:
            fp_radius = self._detector_footprint.compute_radius()
            curr_radius = self.survey_values.get("radius", None)
            if curr_radius is None:
                self.survey_values["radius"] = fp_radius
            elif curr_radius < fp_radius:
                logger.info(
                    f"Provided radius {curr_radius} is smaller than footprint radius {fp_radius}. "
                    "Using the footprint radius instead."
                )
                self.survey_values["radius"] = fp_radius
            else:
                logger.debug(
                    f"Provided radius {curr_radius} is larger than footprint radius {fp_radius}. "
                    "Using the provided radius."
                )

    def get_value_per_row(self, key, *, indices=None, default=None):
        """Get the values for each row from the table or survey values (defaults).

        Parameters
        ----------
        key : str
            The name of the column to retrieve.
        indices : numpy.ndarray, optional
            The indices of the rows for which to retrieve values. If None,
            retrieve all rows.
            Default: None
        default : any, optional
            The default value to use if the key is not found in the table or survey values.
            This can be None to indicate missing values.
            Default: None

        Returns
        -------
        numpy.ndarray
            The values for each row in the table.

        Raises
        ------
        ValueError
            If the survey value is a per-filter dictionary that is missing one of
            the table's filters.
        TypeError
            If the survey value is neither numeric nor a dictionary.
        """
        if indices is None:
            indices = np.arange(len(self._table))

        # Prioritize columns that are in the table.
        if key in self._table.columns:
            return self._table[key].iloc[indices].to_numpy()
        if key in self._inv_colmap and self._inv_colmap[key] in self._table.columns:
            return self._table[self._inv_colmap[key]].iloc[indices].to_numpy()

        # Otherwise fall back to the survey values if they are defined.
        num_rows = len(indices)
        value = self.survey_values.get(key, None)
        if value is None:
            return np.full((num_rows,), default)
        if isinstance(value, (int, float, np.integer, np.floating)):
            # Use the same value for all rows. NumPy scalars are accepted as well as
            # the builtin numeric types.
            return np.full((num_rows,), value)
        if isinstance(value, dict):
            # Map the values for each filter to the rows in the table.
            result = np.zeros(num_rows, dtype=float)
            if len(self.filters) > 0:
                # Extract the selected rows' filters once instead of once per filter.
                row_filters = self._table["filter"].iloc[indices].to_numpy()
                for filt in self.filters:
                    if filt not in value:
                        raise ValueError(f"Dictionary for '{key}' does not have a value for filter '{filt}'")
                    result[row_filters == filt] = value[filt]
            return result
        raise TypeError(f"Unsupported type for '{key}': {type(value)}")

    def safe_get_survey_value(self, key):
        """Get a survey value by key, checking that it is not None.

        Parameters
        ----------
        key : str
            The key of the survey value to retrieve.

        Returns
        -------
        any
            The (non-None) survey value.

        Raises
        ------
        ValueError
            If the survey value is missing or None.
        """
        value = self.survey_values.get(key, None)
        if value is None:
            raise ValueError(
                f"Survey value for {key} is not defined. This should be set when creating the object."
            )
        return value

    @property
    def radius(self):
        """Return the radius if it exists."""
        return self.survey_values.get("radius", None)

    @radius.setter
    def radius(self, new_val):
        """Set the radius, warning if it is smaller than the footprint radius.

        Raises
        ------
        ValueError
            If the new radius is not strictly positive.
        """
        if new_val <= 0:
            raise ValueError(f"Invalid radius: {new_val}")

        if self._detector_footprint is not None:
            fp_radius = self._detector_footprint.compute_radius()
            if new_val < fp_radius:
                warnings.warn(
                    f"Provided radius {new_val} is smaller than footprint radius {fp_radius}. "
                    "This might lead to unexpected results.",
                    stacklevel=2,
                )
        self.survey_values["radius"] = new_val

    @property
    def columns(self):
        """Get the column names."""
        return self._table.columns

    @classmethod
    def from_db(cls, filename, sql_query="SELECT * FROM observations", **kwargs):
        """Create an ObsTable object from the data in an db file.

        Reads data matching what is produced by write_db (and matching the RubinOpsim table).

        Parameters
        ----------
        filename : str
            The name of the db file.
        sql_query : str
            The SQL query to use when loading the table.
            Default: "SELECT * FROM observations"
        kwargs : dict, optional
            Additional keyword arguments to pass to the Survey constructor.

        Returns
        -------
        ObsTable
            A table with all of the pointing data.

        Raises
        ------
        FileNotFoundError
            If the file does not exist.
        ValueError
            If unable to load the table.
        """
        survey_data = read_sqlite_table(filename, table_name=None, sql_query=sql_query)
        return cls(survey_data, **kwargs)

    @classmethod
    def from_parquet(cls, filename, **kwargs):
        """Create an ObsTable object from a parquet file.

        Parameters
        ----------
        filename : str
            The name of the parquet file to read.
        kwargs : dict, optional
            Additional keyword arguments to pass to the Survey constructor.

        Returns
        -------
        ObsTable
            A table with all of the pointing data.

        Raises
        ------
        FileNotFoundError
            If the file does not exist.
        """
        if not Path(filename).is_file():
            raise FileNotFoundError(f"File {filename} not found.")
        survey_data = pd.read_parquet(filename)
        return cls(survey_data, **kwargs)

    def estimate_coverage(self, *, radius=None, max_depth=12, use_footprint=False):
        """Estimate the sky coverage of the observations in the ObsTable.

        This is an approximate calculation based on a constructed MOC at a given depth.

        Parameters
        ----------
        radius : float, optional
            The radius to use for each image (in degrees). Only used if use_footprint is False.
            If None, the radius from the survey values will be used.
        max_depth : int, optional
            The maximum depth of the MOC. Default is 12.
        use_footprint : bool, optional
            Whether to use the detector footprint to build the MOC. If True, the footprint
            will be used to compute the MOC regions for each pointing. If False, a simple
            cone with the given radius will be used.

        Returns
        -------
        coverage : float
            The estimated sky coverage in square degrees.
        """
        moc = self.build_moc(radius=radius, max_depth=max_depth, use_footprint=use_footprint)
        coverage = moc.sky_fraction * 41253.0  # Approximate sky area in deg^2
        return coverage

    def build_moc(
        self,
        *,
        duplicate_threshold=100.0 / 3600.0,
        max_depth=10,
        radius=None,
        use_footprint=False,
    ):
        """Build a Multi-Order Coverage Map from the regions in the data set.

        The MOCs can be either from simple cones around each pointing or from the detector
        footprint if available. The code does not currently support rotation when dealing
        with detector footprints.

        Note
        ----
        This can be **very** slow for large ObsTables or high max_depth values, especially
        when using detector footprints.

        Parameters
        ----------
        duplicate_threshold : float, optional
            The threshold to use for identifying duplicate pointings (in degrees).
            Default is 100.0 / 3600.0 degrees = 100 arcseconds.
        max_depth : int, optional
            The maximum depth of the MOC. Default is 10.
        radius : float, optional
            The radius to use for each image (in degrees). Only used if use_footprint is False.
            If None, the radius from the survey values will be used.
        use_footprint : bool, optional
            Whether to use the detector footprint to build the MOC. If True, the footprint
            will be used to compute the MOC regions for each pointing. If False (or if no
            footprint is set), a simple cone with the given radius will be used.

        Returns
        -------
        MOC
            The Multi-Order Coverage Map constructed from the data set.

        Raises
        ------
        ValueError
            If cones are used and no radius is given or stored in the survey values.
        """
        logger.debug(
            f"Building MOC from ObsTable data: Depth={max_depth}, use_footprint={use_footprint}, "
            f"duplicate_threshold={duplicate_threshold}"
        )

        # Deduplicate near-matching pointings to save computation time.
        ra = self._table["ra"].to_numpy()
        dec = self._table["dec"].to_numpy()
        if duplicate_threshold > 0.0:
            ra, dec, _ = dedup_coords(ra, dec, threshold=duplicate_threshold)
            logger.debug(f"Filtered {len(self._table) - len(ra)} duplicate pointings for MOC construction.")
        logger.debug(f"Building MOC from {len(ra)} unique pointings.")

        if not use_footprint or self._detector_footprint is None:
            radius = radius if radius is not None else self.survey_values.get("radius", None)
            if radius is None:
                raise ValueError("Radius must be provided for MOC construction or as a default. Got None.")

            moc = MOC.from_cones(
                lon=Longitude(ra, unit="deg"),
                lat=Latitude(dec, unit="deg"),
                radius=radius * u.deg,
                max_depth=max_depth,
                delta_depth=0,
                union_strategy="large_cones",
            )
        else:
            # The combination of arbitrary rotations and mocpy does not currently work together
            # (e.g. a rotated rectangle of 90.1 degrees will fail, because the max rotation is capped
            # at pi radians). So we ignore rotation for now.
            if "rotation" in self._table.columns:
                warnings.warn(
                    "MOC construction with footprint does not support rotation. Ignoring rotation values.",
                    stacklevel=2,
                )

            # Union the footprint MOC of every pointing. This stays None if there are no pointings.
            moc = None
            for curr_ra, curr_dec in tqdm(
                zip(ra, dec, strict=False), total=len(ra), desc="Evaluating Region"
            ):
                sky_region, _ = self._detector_footprint.compute_sky_region(curr_ra, curr_dec)
                new_moc = MOC.from_astropy_regions(sky_region, max_depth=max_depth)
                moc = new_moc if moc is None else MOC.union(moc, new_moc)
        return moc

    def plot_footprint(
        self,
        *,
        depth=14,
        fig=None,
        ax=None,
        use_footprint=False,
        **kwargs,
    ):
        """Plot the MOC footprint using matplotlib.

        Parameters
        ----------
        depth : int, optional
            The healpix depth to use for plotting. Default is 14.
        fig : matplotlib.figure.Figure, optional
            An existing matplotlib figure to use. If None, a new figure is created.
        ax : matplotlib.pyplot.Axes or None, optional
            The axes to use for the plot. If None, new axes will be created on the figure.
        use_footprint : bool, optional
            Whether to use the detector footprint to build the MOC. If True, the detector
            footprint will be used. Default is False.
        **kwargs : dict, optional
            Additional keyword arguments to pass to the plot_moc function.

        Returns
        -------
        fig: matplotlib.figure.Figure
            The figure containing the plot.
        ax: matplotlib.pyplot.Axes
            The axes containing the plot.
        """
        moc = self.build_moc(max_depth=depth, use_footprint=use_footprint)
        fig, ax = plot_moc(moc, fig=fig, ax=ax, **kwargs)
        return fig, ax

    def _update_cached_data(self):
        """Update any cached data based on the current table and survey values.

        This can be used any time the underlying table is updated, such as when rows
        are filtered.
        """
        # Reset the index column if it exists.
        self._table = self._table.reset_index(drop=True)

        # Rebuild the list of filters.
        self.filters = np.unique(self._table["filter"]) if "filter" in self._table.columns else np.array([])

        # Build the kd-tree (or other spatial data structure).
        self._spatial_data = None
        self._build_spatial_data()

    def _build_spatial_data(self):
        """Construct the KD-tree from the ObsTable."""
        # Convert the pointings to Cartesian coordinates on a unit sphere.
        x, y, z = ra_dec_to_cartesian(self._table["ra"].to_numpy(), self._table["dec"].to_numpy())
        cart_coords = np.array([x, y, z]).T

        # Construct the kd-tree.
        self._spatial_data = KDTree(cart_coords)

    def _derive_noise_columns(self):
        """Derive any missing noise-related columns (e.g. zero points) from
        the existing columns and survey values.

        Default implementation does not produce a zeropoint column. Subclasses should
        override this method with a survey specific computation.
        """
        pass

    def add_column(self, colname, values, *, overwrite=False):
        """Add a column to the current data table.

        If the column is one that cached data is derived from ("ra", "dec", or "filter"),
        the cached filter list and spatial index are rebuilt so they do not go stale.

        Parameters
        ----------
        colname : str
            The name of the new column.
        values : int, float, str, list, or numpy.ndarray
            The value(s) to add.
        overwrite : bool
            Overwrite the column is it already exists.
            Default: False

        Raises
        ------
        KeyError
            If the column already exists and overwrite is False.
        """
        if colname in self._table.columns and not overwrite:
            raise KeyError(f"Column {colname} already exists.")

        # If the input is a scalar, turn it into an array of the correct length
        if np.isscalar(values):
            values = np.full((len(self._table)), values)
        self._table[colname] = values

        # Keep the caches consistent with the new column contents.
        if colname in ("ra", "dec", "filter"):
            self._update_cached_data()

    def write_db(self, filename, *, tablename="observations", overwrite=False):
        """Write out the observation table as a database to a given SQL table.

        Parameters
        ----------
        filename : str
            The name of the db file.
        tablename : str
            The table to which to write.
            Default: "observations"
        overwrite : bool
            Overwrite the existing table in the DB file.
            Default: False

        Raises
        ------
        ValueError
            If the write fails, including when the table already exists and
            overwrite is False.
        """
        if_exists = "replace" if overwrite else "fail"
        con = sqlite3.connect(filename)
        try:
            self._table.to_sql(tablename, con, if_exists=if_exists)
        except Exception as err:
            raise ValueError("Database write failed.") from err
        finally:
            # Always release the connection, even if the write failed.
            con.close()

    def write_parquet(self, filename, *, overwrite=False):
        """Write out the observation table as a parquet file.

        Parameters
        ----------
        filename : str
            The name of the parquet file.
        overwrite : bool
            Overwrite the existing parquet file.
            Default: False

        Raises
        ------
        FileExistsError
            If the file already exists and overwrite is False.
        """
        if not overwrite and Path(filename).is_file():
            raise FileExistsError(f"File {filename} already exists.")

        # Save all the survey data as metadata. Store a copy so later changes to
        # survey_values do not silently change the table's attributes.
        self._table.attrs["lightcurvelynx_survey_data"] = dict(self.survey_values)
        self._table.to_parquet(filename)

    def time_bounds(self):
        """Returns the min and max times for all observations in the ObsTable.

        Returns
        -------
        t_min, t_max : float, float
            The min and max times for all observations in the ObsTable.
        """
        t_min = self._table["time"].min()
        t_max = self._table["time"].max()
        return t_min, t_max

    def filter_rows(self, rows):
        """Filter the rows in the ObsTable to only include those indices that are provided
        in a list of row indices (integers) or marked True in a mask.

        Parameters
        ----------
        rows : numpy.ndarray
            Either a Boolean array of the same length as the table or list of integer
            row indices to keep. An empty list keeps no rows.

        Returns
        -------
        self : ObsTable
            The filtered ObsTable object.

        Raises
        ------
        ValueError
            If a Boolean mask does not have the same length as the table.
        """
        # Check if we are dealing with a mask of a list of indices.
        rows = np.asarray(rows)
        if rows.dtype == bool:
            if len(rows) != len(self._table):
                raise ValueError(
                    f"Mask length mismatch. Expected {len(self._table)} rows, but found {len(rows)}."
                )
            mask = rows
        else:
            mask = np.full((len(self._table),), False)
            # An empty list converts to a float array, which cannot be used as an index.
            if rows.size > 0:
                mask[rows] = True

        # Filter the rows. Update all of the cached data.
        self._table = self._table[mask]
        self._update_cached_data()
        return self

    def is_observed(self, query_ra, query_dec, *, radius=None, t_min=None, t_max=None):
        """Check if the query point(s) fall within the field of view of any
        pointing in the ObsTable.

        Parameters
        ----------
        query_ra : float or numpy.ndarray
            The query right ascension (in degrees).
        query_dec : float or numpy.ndarray
            The query declination (in degrees).
        radius : float or None, optional
            The angular radius of the observation (in degrees).
        t_min : float or None, optional
            The minimum time (in MJD) for the observations to consider.
            If None, no time filtering is applied.
        t_max : float or None, optional
            The maximum time (in MJD) for the observations to consider.
            If None, no time filtering is applied.

        Returns
        -------
        seen : bool or list[bool]
            Depending on the input, this is either a single bool to indicate
            whether the query point is observed or a list of bools for an array
            of query points.
        """
        # NOTE(review): range_search is not defined in the visible part of this class;
        # presumably it is provided elsewhere -- confirm.
        inds = self.range_search(query_ra, query_dec, radius=radius, t_min=t_min, t_max=t_max)
        if np.isscalar(query_ra):
            return len(inds) > 0
        return [len(entry) > 0 for entry in inds]

    def get_observations(self, query_ra, query_dec, *, radius=None, t_min=None, t_max=None, cols=None):
        """Return the observation information when the query point falls within
        the field of view of a pointing in the ObsTable.

        Parameters
        ----------
        query_ra : float
            The query right ascension (in degrees).
        query_dec : float
            The query declination (in degrees).
        radius : float or None, optional
            The angular radius of the observation (in degrees). If None
            uses the default radius for the ObsTable.
        t_min : float or None, optional
            The minimum time (in MJD) for the observations to consider.
            If None, no time filtering is applied.
        t_max : float or None, optional
            The maximum time (in MJD) for the observations to consider.
            If None, no time filtering is applied.
        cols : list or str
            A list of the names of columns to extract or a single column name.
            If None returns all the columns.

        Returns
        -------
        results : dict
            A dictionary mapping the given column name to a numpy array of values.

        Raises
        ------
        KeyError
            If a requested column is not a table column or a defined survey value.
        """
        # NOTE(review): range_search is not defined in the visible part of this class;
        # presumably it is provided elsewhere -- confirm.
        neighbors = self.range_search(query_ra, query_dec, radius=radius, t_min=t_min, t_max=t_max)

        if cols is None:
            cols = self._table.columns.to_list()
        elif isinstance(cols, str):
            cols = [cols]

        results = {}
        for col in cols:
            # Allow the user to specify either the original or mapped column names.
            if col not in self:
                raise KeyError(f"Unrecognized column name {col}")
            # get_value_per_row follows the same lookup order as __getitem__, but also
            # expands survey values (which have no .iloc) to one entry per matched row.
            results[col] = self.get_value_per_row(col, indices=neighbors)
        return results

    def compute_saturation(self, flux, flux_error, index):
        """Apply the saturation limits to a given flux and flux error.

        When a flux value exceeds the saturation limit, it is clipped to the limit and
        flagged as saturated. In these cases, the associated flux_error is increased to
        account for the offset introduced by clipping. The new error is computed as the
        quadrature sum of the original flux_error and the difference between the original
        flux and saturated flux::

            saturated_flux_error = sqrt(flux_error**2 + (flux - saturated_flux)**2)

        For unsaturated points, both flux and flux_error are returned unchanged.
        Filters without an entry in the saturation thresholds are never saturated.

        Parameters
        ----------
        flux : numpy.ndarray of float
            The bandflux in nJy. A size S x T array where S is the number of samples in
            the graph state and T is the number of time points.
        flux_error : numpy.ndarray of float
            The bandflux error in nJy. A size S x T array where S is the number of samples
            in the graph state and T is the number of time points.
        index : array_like of int
            The index of the observation in the ObsTable table.

        Returns
        -------
        tuple of numpy.ndarray
            A tuple with three entries:
            - The saturated flux in nJy.
            - The saturated flux error in nJy.
            - A boolean array indicating which points are saturated.
            Each has the same shape as the input flux.

        Raises
        ------
        ValueError
            If the inputs do not have matching lengths (compared with len(), i.e. along
            the first axis) or a saturation threshold is not numeric.
        """
        true_flux = np.asarray(flux)
        true_flux_error = np.asarray(flux_error)

        if self._saturation_mags is None:
            logger.info("Saturation thresholds not provided. Skipping saturation computation.")
            # Use the array view for the shape so that list inputs are supported too.
            return flux, flux_error, np.full(true_flux.shape, False)

        filters = np.asarray(self._table["filter"].iloc[index])
        if len(flux) != len(flux_error) or len(flux) != len(filters):
            raise ValueError("Input arrays must have the same length.")

        # Convert saturation thresholds to nJy.
        saturation_mags_njy = {}
        for filt, mag in self._saturation_mags.items():
            if not isinstance(mag, (int, float, np.integer, np.floating)):
                raise ValueError("Saturation thresholds must be numeric.")
            saturation_mags_njy[filt] = mag2flux(mag)

        # Map the filter list to saturation limits (no limit for unlisted filters).
        limits = np.array([saturation_mags_njy.get(filt, np.inf) for filt in filters])

        # Calculate the saturated flux and flux error.
        saturated_flux = np.minimum(true_flux, limits)
        saturated_flux_error = np.hypot(true_flux_error, (true_flux - saturated_flux))
        saturated_flux_error = np.where(true_flux <= limits, true_flux_error, saturated_flux_error)

        # Create a flag array to indicate which points are saturated.
        saturation_flags = true_flux > limits

        return saturated_flux, saturated_flux_error, saturation_flags

    @staticmethod
    def _expand_coordinate(values, num_samples, name):
        """Return a coordinate input as an array with one entry per sample.

        Scalars (including ints and NumPy scalars) are broadcast to a float array.
        """
        if np.isscalar(values):
            # Always use a float dtype so the coordinate is never truncated.
            return np.full(num_samples, values, dtype=float)
        if len(values) != num_samples:
            raise ValueError(f"If {name} is an array, it must have the same length as times.")
        return np.asarray(values)

    def _sample_row_indices(self, rng, num_samples, filters=None):
        """Randomly choose (with replacement) the table rows to use for each new sample.

        If filters is given, each sample is drawn only from rows with the matching filter.
        """
        if filters is None:
            return rng.integers(0, len(self._table), size=num_samples)

        # Compare as arrays so list inputs are matched element-wise.
        filters = np.asarray(filters)
        table_filters = self._table["filter"].to_numpy()
        sample_indices = np.full(num_samples, -1, dtype=int)
        for current_filter in np.unique(filters):
            selected = filters == current_filter
            matching_indices = np.flatnonzero(table_filters == current_filter)
            if len(matching_indices) == 0:
                raise ValueError(f"No observations found in filter '{current_filter}' to match.")
            sample_indices[selected] = rng.choice(matching_indices, size=np.sum(selected), replace=True)
        return sample_indices

    def make_resampled_table(
        self,
        times,
        *,
        ra=None,
        dec=None,
        filter=None,
        match_filter=False,
        seed=None,
        **kwargs,
    ):
        """Create a new ObsTable object that is resampled to the given times (and
        optionally positions and filters). All other columns, including noise data, are
        sampled from the existing table.

        By default the sampled rows are drawn from the entire table. However users can
        force the rows to be drawn from only those matching the given filter array by
        setting match_filter=True.

        Parameters
        ----------
        times : array_like of float
            The times (in MJD) for the new observations.
        ra : float or array_like of float, optional
            The right ascension(s) (in degrees) for the new observations. If a single
            value is provided, it is used for all times. If None is provided, the
            existing RA values are kept.
            Default: None
        dec : float or array_like of float, optional
            The declination(s) (in degrees) for the new observations. If a single
            value is provided, it is used for all times. If None is provided, the
            existing Dec values are kept.
            Default: None
        filter : str or array_like of str, optional
            The filter(s) for the new observations. If a single string is provided, it is
            used for all times. If None is provided, the existing filter values are kept.
            Default: None
        match_filter : bool, optional
            If True, when sampling from the existing table, only sample observations that
            match the provided filter(s). This is only relevant if filter is provided.
            Default: False
        seed : int, optional
            A random seed for the resampling. If None, a random seed is used. Most users
            should use None here to get different results on each call.
            Default: None
        kwargs : dict, optional
            Additional keyword arguments to pass to the ObsTable constructor for the new
            object. These can overwrite any of the values from the original object.

        Returns
        -------
        ObsTable
            A new ObsTable object with the resampled observations.

        Raises
        ------
        ValueError
            If only one of ra and dec is given, if an array input does not have the same
            length as times, or if match_filter is set and a requested filter has no
            observations in the table.
        """
        rng = np.random.default_rng(seed)

        times = np.asarray(times)
        num_samples = len(times)

        # If filters were given, set those up as an array of the correct size.
        if filter is not None:
            if isinstance(filter, str):
                filter = np.array([filter] * num_samples, dtype=object)
            elif len(filter) != num_samples:
                raise ValueError("If filter is an array, it must have the same length as times.")
            else:
                filter = np.asarray(filter)

        # Positions must be given together.
        if ra is not None and dec is None:
            raise ValueError("If ra is provided, dec must also be provided.")
        if dec is not None and ra is None:
            raise ValueError("If dec is provided, ra must also be provided.")

        # Select the sample indices. If we are matching filters, we do this on a
        # per-filter basis. Otherwise we sample randomly from the whole table.
        sample_indices = self._sample_row_indices(
            rng, num_samples, filter if match_filter else None
        )

        # Create the new table. Fill in the time (and optionally filter) column.
        new_table = self._table.iloc[sample_indices].reset_index(drop=True)
        new_table["time"] = times
        if filter is not None:
            new_table["filter"] = filter

        # If positions were given, overwrite those as well.
        if ra is not None:
            new_table["ra"] = self._expand_coordinate(ra, num_samples, "ra")
            new_table["dec"] = self._expand_coordinate(dec, num_samples, "dec")

        # Create a copy of the kwargs and add anything that is missing. This allows the
        # users to override any of the survey values if desired.
        survey_kwargs = kwargs.copy()
        survey_kwargs.setdefault("colmap", self._colmap if len(self._colmap) > 0 else None)
        survey_kwargs.setdefault("detector_footprint", self._detector_footprint)
        # None because we have already converted to DetectorFootprint.
        survey_kwargs.setdefault("wcs", None)
        survey_kwargs.setdefault("saturation_mags", self._saturation_mags)
        for key, value in self.survey_values.items():
            survey_kwargs.setdefault(key, value)

        # Create a new object using the correct subclass of ObsTable.
        return self.__class__(new_table, **survey_kwargs)