"""A collection of functions for plotting and visualization."""
import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
from lightcurvelynx.astro_utils.mag_flux import flux2mag
def _build_colormap(unique_filters):
    """Assign a distinct plotting color to each filter.

    Parameters
    ----------
    unique_filters : iterable
        The filter names that need colors. Iteration order determines which
        color each filter receives.

    Returns
    -------
    colormap : dict
        A mapping from each filter name to the color sampled for it.
    """
    names = list(unique_filters)
    count = len(names)

    # Small sets use the qualitative "tab" palettes (sampled at their native size)
    # so neighboring filters stay visually distinct; larger sets fall back to a
    # continuous colormap resampled to exactly `count` entries.
    if count > 20:
        palette = plt.get_cmap("turbo", count)
    elif count > 10:
        palette = plt.get_cmap("tab20", 20)
    else:
        palette = plt.get_cmap("tab10", 10)

    return {name: palette(index) for index, name in enumerate(names)}
[docs]
def plot_lightcurves(
    fluxes,
    times,
    *,
    fluxerrs=None,
    filters=None,
    underlying_model=None,
    ax=None,
    figure=None,
    title=None,
    colormap=None,
    plot_magnitudes=False,
    **kwargs,
):
    """Plot one or more light curves.

    Parameters
    ----------
    fluxes : numpy.ndarray
        An array of T flux values. Array-like inputs (e.g. lists) are converted
        to numpy arrays.
    times : numpy.ndarray
        A length T array of the times, used for setting the x axis.
        All times in MJD.
    fluxerrs : numpy.ndarray or None, optional
        A length T array of errors on the fluxes for error bars. If not provided
        no error bars are created. None by default.
    filters : numpy.ndarray or None, optional
        A length T array of filter names. If not provided all points are
        treated as coming from a single filter labeled "None". None by default.
    underlying_model : dict or None, optional
        A dictionary mapping filter names to the noise free light curves for this
        model, plus a "times" entry holding the model's time grid. If provided,
        these curves will be plotted as lines behind the data points. The
        dictionary is never modified. None by default.
    ax : matplotlib.pyplot.Axes or None, optional
        Axes, None by default.
    figure : matplotlib.pyplot.Figure or None
        Figure, None by default.
    title : str or None, optional
        Title of the plot. None by default.
    colormap : dict, optional
        A dictionary that provides mapping between filters and the colors to be plotted.
        When filters is None, the single filter key is "None".
    plot_magnitudes : bool, optional
        Whether to plot magnitudes instead of fluxes. Points with non-positive
        flux are dropped in this mode. False by default.
    **kwargs : dict
        Optional parameters to pass to the plotting function.

    Returns
    -------
    ax : matplotlib.pyplot.Axes
        The axes containing the plot.

    Raises
    ------
    ValueError
        If times, filters, or fluxerrs do not have the same length as fluxes.
    """
    # If no axes were given create them using either the given figure or
    # a newly created one (if no figure is given).
    if ax is None:
        if figure is None:
            figure = plt.figure()
        ax = figure.add_axes([0, 0, 1, 1])

    # Use arrays so boolean masking works for array-like inputs.
    fluxes = np.asarray(fluxes)
    times = np.asarray(times)
    num_pts = len(fluxes)

    # Validate all sizes up front (before any magnitude masking) so that
    # mismatched inputs produce a clear ValueError.
    if len(times) != num_pts:
        raise ValueError(f"Mismatched array sizes for fluxes ({num_pts}) and times ({len(times)}).")
    if filters is None:
        # Every point belongs to one unnamed filter. The per-point label must match
        # the name used as the unique filter, or the mask below selects nothing.
        filters = np.full(num_pts, "None")
    elif len(filters) == num_pts:
        filters = np.asarray(filters)
    else:
        raise ValueError(f"Mismatched array sizes for fluxes ({num_pts}) and filters ({len(filters)}).")
    if fluxerrs is not None:
        if len(fluxerrs) != num_pts:
            raise ValueError(f"Mismatched array sizes for fluxes ({num_pts}) and fluxerrs ({len(fluxerrs)}).")
        fluxerrs = np.asarray(fluxerrs)

    if plot_magnitudes:
        # Magnitudes are only defined for positive fluxes, so drop every other point.
        valid_mask = fluxes > 0
        if fluxerrs is not None:
            # First-order error propagation: sigma_m = (2.5 / ln 10) * sigma_f / f.
            fluxerrs = (2.5 / np.log(10)) * (fluxerrs[valid_mask] / fluxes[valid_mask])
        fluxes = np.asarray(flux2mag(fluxes[valid_mask]))
        times = times[valid_mask]
        filters = filters[valid_mask]

        # Convert a copy of the underlying model so the caller's dictionary is left
        # untouched (converting in place would double-convert on a repeat call).
        if underlying_model is not None:
            underlying_model = {
                key: (value if key == "times" else flux2mag(value)) for key, value in underlying_model.items()
            }

    unique_filters = np.unique(filters)
    if colormap is None:
        colormap = _build_colormap(unique_filters)

    # Plot the data with one series for each filter.
    for filt in unique_filters:
        filter_mask = filters == filt
        color = colormap[filt]

        # Plot the underlying model (if provided) first so it sits behind the data.
        if underlying_model is not None and filt in underlying_model:
            ax.plot(
                underlying_model["times"],
                underlying_model[filt],
                linestyle="-",
                color=color,
                alpha=0.5,
                label=f"Model {filt}",
            )

        if fluxerrs is None:
            ax.plot(
                times[filter_mask],
                fluxes[filter_mask],
                marker="o",
                label=f"Sample {filt}",
                color=color,
                **kwargs,
            )
        else:
            ax.errorbar(
                times[filter_mask],
                fluxes[filter_mask],
                yerr=fluxerrs[filter_mask],
                fmt="o",
                label=f"Sample {filt}",
                color=color,
                **kwargs,
            )

    # Set the title and axis labels.
    if title is not None:
        ax.set_title(title)
    ax.set_xlabel("Time (MJD)")
    if plot_magnitudes:
        ax.set_ylabel("Magnitude")
        ax.invert_yaxis()  # Lower magnitudes are brighter, so they belong at the top.
    else:
        ax.set_ylabel("Flux (nJy)")

    # Only include a legend if there are at least two filters.
    if len(unique_filters) > 1:
        ax.legend()
    return ax
[docs]
def plot_bandflux_lightcurves(bandflux, times=None, ax=None, figure=None, title=None):
    """Plot one or more light curves where each band is observed at each time.

    This is primarily used for visualizing non-sampled data.

    Parameters
    ----------
    bandflux : numpy.ndarray or dict
        Either a single array with the light curve or a dictionary mapping
        light curve names to the arrays of values. Any non-mapping input is
        treated as a single light curve named "lightcurve".
    times : numpy.ndarray or None, optional
        A length T array of the times, used for setting the x axis. If not
        provided, each curve is plotted against its own sample indices
        (0, 1, ..., len(curve) - 1). None by default.
    ax : matplotlib.pyplot.Axes or None, optional
        Axes, None by default.
    figure : matplotlib.pyplot.Figure or None
        Figure, None by default.
    title : str or None, optional
        Title of the plot. None by default.

    Returns
    -------
    ax : matplotlib.pyplot.Axes
        The axes containing the plot.
    """
    from collections.abc import Mapping

    # If no axes were given create them using either the given figure or
    # a newly created one (if no figure is given).
    if ax is None:
        if figure is None:
            figure = plt.figure()
        ax = figure.add_axes([0, 0, 1, 1])

    # Wrap a single light curve (array, list, ...) so both input forms share one path.
    if not isinstance(bandflux, Mapping):
        bandflux = {"lightcurve": bandflux}

    for name, curve in bandflux.items():
        # Build default x values per curve instead of overwriting `times`, so curves
        # of different lengths each get a matching index axis.
        x_values = np.arange(len(curve)) if times is None else times
        ax.plot(x_values, curve, marker="o", label=name)

    # Set the title and axis labels.
    if title is not None:
        ax.set_title(title)
    ax.set_xlabel("Time (days)")
    ax.set_ylabel("Flux")

    # Only include a legend if there are at least two curves.
    if len(bandflux) > 1:
        ax.legend()
    return ax
[docs]
def plot_flux_spectrogram(flux_density, times=None, wavelengths=None, ax=None, figure=None, title=None):
    """Plot a spectrogram to visualize the fluxes.

    The image is drawn with time along the x axis and wavelength along the
    y axis. A sparse grid of cells is annotated with its flux value.

    Parameters
    ----------
    flux_density : numpy.ndarray
        A T x N matrix of SED values (in nJy), where T is the number of time steps,
        and N is the number of wavelengths. Array-like inputs are converted to
        numpy arrays.
    times : numpy.ndarray or None, optional
        A length T array of the times, used for labeling the x axis (every 4th
        time is shown as a tick). If not provided, the default ticks are kept.
        None by default.
    wavelengths : numpy.ndarray or None, optional
        A length N array of the wavelengths (in Angstroms), used for labeling the
        y axis (every 50th wavelength is shown as a tick). If not provided, the
        default ticks are kept. None by default.
    ax : matplotlib.pyplot.Axes or None, optional
        Axes, None by default.
    figure : matplotlib.pyplot.Figure or None
        Figure, None by default.
    title : str or None, optional
        Title of the plot. "Flux Spectrogram" is used if None. None by default.

    Returns
    -------
    ax : matplotlib.pyplot.Axes
        The axes containing the plot.
    """
    # If no axes were given create them using either the given figure or
    # a newly created one (if no figure is given).
    if ax is None:
        if figure is None:
            figure = plt.figure()
        ax = figure.add_axes([0, 0, 1, 1])

    # Transpose so rows are wavelengths (y axis) and columns are times (x axis).
    image = np.asarray(flux_density).T
    ax.imshow(image, cmap="plasma", interpolation="nearest", aspect="auto")

    # Add title, axis labels, and correct ticks.
    ax.set_title("Flux Spectrogram" if title is None else title)
    if times is not None:
        ax.set_xlabel("Time (days)")
        ax.set_xticks(np.arange(len(times))[::4], [f"{round(time)}" for time in times][::4])
    if wavelengths is not None:
        ax.set_ylabel("Wavelength (Angstrom)")
        ax.set_yticks(np.arange(len(wavelengths))[::50], [f"{round(wave)}" for wave in wavelengths][::50])

    # Annotate every other time step (odd indices) at every 40th wavelength
    # starting from index 20. Visit only those cells rather than scanning the
    # whole image; rows-then-columns preserves the original annotation order.
    num_waves, num_times = image.shape
    for j in range(20, num_waves, 40):
        for i in range(1, num_times, 2):
            ax.text(i, j, round(image[j, i], 1), ha="center", va="center", size=8)
    return ax
[docs]
def plot_moc(
    moc,
    *,
    fig=None,
    ax=None,
    **kwargs,
):
    """Plot a Multi-Order Coverage (MOC) map of the sky.

    Parameters
    ----------
    moc : mocpy.MOC
        The MOC object to plot.
    fig : matplotlib.pyplot.Figure or None, optional
        The figure to use for the plot. If None, the figure of ``ax`` is used when
        ``ax`` is given; otherwise a new figure will be created.
    ax : matplotlib.pyplot.Axes or None, optional
        The axes to use for the plot. Must be a ``WCSAxes`` if given. If None, new
        axes with a full-sky Mollweide projection will be created on the figure.
    kwargs : dict
        Additional keyword arguments to pass to the moc.fill() function for
        customizing the plot. These override the default fill style.

    Returns
    -------
    fig: matplotlib.figure.Figure
        The figure containing the plot.
    ax: matplotlib.pyplot.Axes
        The axes containing the plot.

    Raises
    ------
    ValueError
        If ``ax`` is given but is not a ``WCSAxes``.
    """
    from astropy.coordinates import Angle, SkyCoord
    from astropy.visualization.wcsaxes import WCSAxes
    from astropy.visualization.wcsaxes.frame import EllipticalFrame
    from mocpy import WCS

    if fig is None:
        # If we are given an axes, use the figure from the axes. Otherwise, create a new figure.
        fig = plt.figure() if ax is None else ax.get_figure()

    if ax is not None:
        # If an axis is given, we use that to determine the WCS and frame type.
        if not isinstance(ax, WCSAxes):
            raise ValueError("If ax is given, it must be a WCSAxes.")
        wcs = ax.wcs
        frame_type = type(ax.coords.frame)
    else:
        # We create a WCS that covers the full sky in a Mollweide projection.
        wcs = WCS(
            fig,
            fov=(320 * u.deg, 160 * u.deg),
            center=SkyCoord(0, 0, unit="deg", frame="icrs"),
            coordsys="icrs",
            rotation=Angle(0, u.deg),
            projection="MOL",
        ).w
        frame_type = EllipticalFrame

    # Create the axes if they were not given.
    if ax is None:
        ax = fig.add_subplot(projection=wcs, frame_class=frame_type)
    ax.coords[0].set_format_unit("deg")

    # Default fill style; any user-supplied kwargs take precedence.
    mocpy_args = {
        "alpha": 0.5,
        "edgecolor": "darkblue",
        "facecolor": "blue",
        "fill": True,
        **kwargs,
    }

    # Plot the MOC.
    moc.fill(ax, wcs, **mocpy_args)

    # Decorate the axes we actually plotted into. The pyplot state functions
    # (plt.grid / plt.xlabel / plt.ylabel) act on the *current* axes, which need
    # not be `ax` (e.g. a caller-supplied axes on a non-current figure).
    ax.grid()
    ax.set_ylabel("Dec")
    ax.set_xlabel("RA")
    return fig, ax