Source code for lightcurvelynx.models.eztaox_models

"""Wrappers for the models defined in EZTaoX.

https://github.com/LSST-AGN-Variability/EzTaoX
"""

import importlib

import numpy as np
from citation_compass import CiteClass

from lightcurvelynx.astro_utils.mag_flux import mag2flux
from lightcurvelynx.math_nodes.np_random import NumpyRandomFunc
from lightcurvelynx.models.physical_model import BandfluxModel


[docs] class EzTaoXWrapperModel(BandfluxModel, CiteClass): """A wrapper for an eztaox model. Parameterized values include: * dec - The object's declination in degrees. [from BasePhysicalModel] * distance - The object's luminosity distance in pc. [from BasePhysicalModel] * ra - The object's right ascension in degrees. [from BasePhysicalModel] * redshift - The object's redshift. [from BasePhysicalModel] * t0 - The t0 of the zero phase, date. [from BasePhysicalModel] Additional parameterized values are used for specific eztaox models. References ---------- * Weixiang Yu et al., 2025 “Scalable and Robust Multiband Modeling of AGN Light Curves in Rubin-LSST" DOI: 10.48550/arXiv.2511.21479 Attributes ---------- kernel : eztaox kernel object An eztaox kernel object to use for the Gaussian process modeling of the light curve. num_kernel_params : int The number of kernel parameters for the eztaox kernel. num_filters : int The number of filters in the model. zero_mean : bool Whether to use a zero mean model. has_lag : bool Whether the model includes lag parameters. filter_idx : dict A mapping from filter name to an integer index. Parameters ---------- kernel : eztaox kernel object An eztaox kernel object to use for the Gaussian process modeling of the light curve. baseline_mags : dict, optional A mapping from filter name to the setter baseline magnitude for that filter. If not provided, the model will use zero mean or the mean_func/mean_mag parameters. Default is None. log_kernel_param : list of setters, required Setters for each of the log kernel parameters. These must be in the order expected by the kernel functions. amp_scale_func : Callable, optional A callable amplitude scaling function, defaults to None. log_amp_scale : list of setters, optional Setters for the log amplitude scale for each filter (length N) if amp_scale_func is not provided. Default is None. zero_mean : bool Whether to use a zero mean model. If False then the program will try to use (in order): mean_func, mean (values), or a default mean function. Default is True. mean_func : Callable, optional If provided and zero_mean is False this is used to compute the mean function for bands. Default is None. mean_mag : list of setters, optional Setters for the mean magnitude for each filter except the first (length N-1) if zero_mean is False and mean_func is not provided. Default is None. has_lag : bool Whether the model includes lag parameters. If True then the program will try to use (in order): lag_func, lag (values), or a default lag function. Default is False. lag_func : Callable, optional If provided and has_lag is True this is used to compute the lag. Default is None. lag : list of setters, optional Setters for the lag for each filter except the first (length N) if has_lag is True and lag_func is not provided. Default is None. band_list : list, optional A list of band names in order. If not provided, the default list for ugrizy filters is used. seed_param : setter, optional A setter for the seed parameter to use for each run. If not provided, a random seed is generated for each run. Default is None. **kwargs : dict, optional Any additional keyword arguments. """ # Convenience mapping from filter name to index in the parameter list. _default_filter_idx = {"u": 0, "g": 1, "r": 2, "i": 3, "z": 4, "y": 5} def __init__( self, kernel, *, baseline_mags=None, log_kernel_param=None, amp_scale_func=None, log_amp_scale=None, zero_mean=True, mean_func=None, mean_mag=None, has_lag=False, lag_func=None, lag=None, band_list=None, seed_param=None, **kwargs, ): self._cached_data = {} super().__init__(**kwargs) # Confirm that the needed packages are installed. if importlib.util.find_spec("eztaox") is None: # pragma: no cover raise ImportError( "The EzTaoX package is required to use the EzTaoXWrapperModel. " "Please install it from https://github.com/LSST-AGN-Variability/EzTaoX" ) if importlib.util.find_spec("jax") is None: # pragma: no cover raise ImportError( "JAX is required to use the EzTaoXWrapperModel class, please " "install with `pip install jax` or `conda install conda-forge::jax`" ) # Store the kernel and filter index mapping.
[docs] self.kernel = kernel
if band_list is not None: self.filter_idx = {band: idx for idx, band in enumerate(band_list)} else: self.filter_idx = self._default_filter_idx
[docs] self.num_filters = len(self.filter_idx)
# Add the log kernel parameters to this model node. Since the order is # the important aspect, we just name them by index. if log_kernel_param is None: raise ValueError("The log_kernel_param parameter setters must be provided.")
[docs] self.num_kernel_params = len(log_kernel_param)
for i, setter in enumerate(log_kernel_param): self.add_parameter(f"eztaox_log_kernel_param_{i}", setter) # Add the amplitude scale information as either a function or parameters to this node. self._amp_scale_func = amp_scale_func if log_amp_scale is None and amp_scale_func is None: raise ValueError("One of either amp_scale_func or log_amp_scale must be provided.") self._has_log_amp_scale = log_amp_scale is not None if self._has_log_amp_scale: for i, setter in enumerate(log_amp_scale): self.add_parameter(f"eztaox_log_amp_scale_{i}", setter) if len(log_amp_scale) != self.num_filters: raise ValueError( f"The number of log amplitude scale parameter setters {len(log_amp_scale)} " f"must be equal to the number of filters {self.num_filters}." ) # Store the mean magnitude parameters and callable if provided.
[docs] self.zero_mean = zero_mean
self._mean_func = mean_func # The callable if it is provided. self._has_mean_vals = mean_mag is not None if self._has_mean_vals: if baseline_mags is not None: raise ValueError( "If mean_mag parameter setters are provided, then baseline_mags cannot also be provided." ) if len(mean_mag) != self.num_filters - 1: raise ValueError( f"The number of mean parameter setters {len(mean_mag)} must be equal to the " f"number of filters minus one {self.num_filters - 1}." ) for i, setter in enumerate(mean_mag): self.add_parameter(f"eztaox_band_mean_{i}", setter) # Store the lag parameters if provided.
[docs] self.has_lag = has_lag
self._lag_func = lag_func # The callable if it is provided. self._has_lag_values = lag is not None if self._has_lag_values: if len(lag) != self.num_filters: raise ValueError( f"The number of lag parameter setters {len(lag)} must be equal to the " f"number of filters {self.num_filters}." ) for i, setter in enumerate(lag): self.add_parameter(f"eztaox_band_lag_{i}", setter) # Store the baseline magnitude information. for filter in self.filter_idx: param_name = f"eztaox_baseline_mag_{filter}" if baseline_mags is not None and filter in baseline_mags: self.add_parameter(param_name, baseline_mags[filter]) else: self.add_parameter(param_name, 0.0) # Default to 0.0 if not provided. # The seed used per run can be defined by the seed_param. If is not provided, # we randomly generate a seed for each run. if seed_param is None: seed_param = NumpyRandomFunc("integers", low=0, high=2**32 - 1) self.add_parameter("eztaox_seed_param", seed_param) def _compute_all_bandfluxes(self, times, filters, state): """Evaluate the model at the passband level for a single, given graph state and and all of the filters at once. We do this and cache the results to avoid recomputing the same model for each band. Parameters ---------- times : numpy.ndarray A length T array of observer frame timestamps in MJD. filters : list of str The names of the filters (one at each time). state : GraphState An object mapping graph parameters to their values with num_samples=1. This is not used in this model, but is required for the function signature. """ if ( "bandfluxes" in self._cached_data and state is self._cached_data.get("state") and np.array_equal(self._cached_data["filters"], filters) and np.array_equal(self._cached_data["times"], times) ): # pragma: no cover return # Nothing to do, the cache is valid. # Cache the input data (times and filters). self._cached_data = {} # Clear the cache (just in case). self._cached_data["times"] = times self._cached_data["filters"] = filters self._cached_data["state"] = state # Import the dependencies that we need for this computation (these have all # been checked in the constructor). import jax import jax.numpy as jnp from eztaox.simulator import MultiVarSim # Shift the times to be relative to start at 0. delta_t = jnp.array(times) - jnp.min(jnp.array(times)) # Shift times to start at 0. # Extract the local parameters for this object from the full state object and build # the parameter dict needed by the simulator. This parameter dict must include: # - log_kernel_param a JAX numpy array of shape (num_kernel_params,) # - log_amp_scale a JAX numpy array of shape (num_filters,) # Optional it may also include: # - mean a JAX numpy array of shape (num_filters - 1,) # - lag a JAX numpy array of shape (num_filters,) local_params = self.get_local_params(state) init_params = {} init_params["log_kernel_param"] = jnp.array( [local_params[f"eztaox_log_kernel_param_{i}"] for i in range(self.num_kernel_params)] ) if self._has_log_amp_scale: init_params["log_amp_scale"] = jnp.array( [local_params[f"eztaox_log_amp_scale_{i}"] for i in range(self.num_filters)] ) if self._has_mean_vals: init_params["mean"] = jnp.array( [local_params[f"eztaox_band_mean_{i}"] for i in range(self.num_filters - 1)] ) if self._has_lag_values: init_params["lag"] = jnp.array( [local_params[f"eztaox_band_lag_{i}"] for i in range(self.num_filters)] ) # Create the simulator object using the given parameters and run it. sim = MultiVarSim( self.kernel, 0.01, jnp.max(delta_t), # The last time to simulate. self.num_filters, init_params=init_params, mean_func=self._mean_func, amp_scale_func=self._amp_scale_func, lag_func=self._lag_func, zero_mean=self.zero_mean, has_lag=self.has_lag, ) # Compute the list of bands as integer indices for the simulator and save them to the cache. band_indices = jnp.array([self.filter_idx[f] for f in filters]) # Compute the list of magnitudes for all given times, transform them to fluxes, # save them in the cache, and return them. _, mags = sim.fixed_input_fast( (delta_t, band_indices), # Tuple of times and band indices jax.random.PRNGKey(local_params["eztaox_seed_param"]), # Use the per-run seed. ) # Add in the baseline magnitudes. Note we need to create a new copy of the array here, # because it is a JAX array and cannot be modified in place. mags = np.array(mags) for filter in self.filter_idx: baseline_mag = local_params[f"eztaox_baseline_mag_{filter}"] filter_mask = filters == filter mags[filter_mask] += baseline_mag # Convert to fluxes and cache. bandfluxes = mag2flux(mags) self._cached_data["bandfluxes"] = bandfluxes
[docs] def compute_bandflux(self, times, filter, state): """Evaluate the model at the passband level for a single, given graph state and filter. Parameters ---------- times : numpy.ndarray A length T array of observer frame timestamps in MJD. filter : str The name of the filter. state : GraphState An object mapping graph parameters to their values with num_samples=1. This is not used in this model, but is required for the function signature. Returns ------- bandflux : numpy.ndarray A length T array of band fluxes for this model in this filter. """ if self._cached_data is None or state is not self._cached_data.get("state"): raise NotImplementedError( "The compute_bandflux method should not be called directly for the " "EzTaoXWrapperModel. Instead, use the evaluate_bandfluxes method which " "handles caching of the bandflux computations." ) # Extract the bandfluxes from the entries with matching filters. filter_mask = self._cached_data["filters"] == filter return self._cached_data["bandfluxes"][filter_mask]
[docs] def evaluate_bandfluxes(self, passband_or_group, times, filters, state, rng_info=None) -> np.ndarray: """Get the band fluxes for a given Passband or PassbandGroup. Parameters ---------- passband_or_group : Passband or PassbandGroup or None The passband (or passband group) to use. Not used in this function. times : numpy.ndarray A length T array of observer frame timestamps in MJD. filters : numpy.ndarray or None A length T array of filter names. It may be None if passband_or_group is a Passband. state : GraphState or None An object mapping graph parameters to their values. rng_info : numpy.random._generator.Generator, optional A given numpy random number generator to use for this computation. If not provided, the function uses the node's random number generator. Returns ------- bandfluxes : numpy.ndarray A matrix of the band fluxes. If only one sample is provided in the GraphState, then returns a length T array. Otherwise returns a size S x T array where S is the number of samples in the graph state. """ # Check if we need to sample the graph. if state is None: state = self.sample_parameters(num_samples=1, rng_info=rng_info) # Compute the bandfluxes using the cached method. self._compute_all_bandfluxes(times, filters, state) # Call the parent method to extract the bandfluxes in the right format. bandfluxes = super().evaluate_bandfluxes(passband_or_group, times, filters, state, rng_info=rng_info) # Clear the cache self._cached_data.clear() return bandfluxes