Source code for edges.analysis.plots

"""Plotting utilities."""

import matplotlib.pyplot as plt
import numpy as np
from astropy import units as apu
from pygsdata import GSData
from pygsdata.select import select_lsts

from .. import modeling as mdl
from ..averaging import averaging
from ..averaging.lstbin import average_over_times
from .datamodel import add_model


def plot_time_average(
    data: GSData,
    ax: plt.Axes | None = None,
    logy=None,
    lst_min: float = 0,
    lst_max: float = 24,
    load: int = 0,
    pol: int = 0,
    attribute: str = "data",
    offset: float = 0.0,
):
    """Make a 1D plot of the time-averaged data.

    Parameters
    ----------
    data
        The GSData object to plot.
    ax
        The axis to plot on. If None, a new axis is created.
    logy
        Whether to plot a logarithmic y-axis. If None, the y-axis is logarithmic
        if all the plotted data (i.e. the selected load/pol, after the offset has
        been applied) is positive.
    lst_min
        The minimum LST to average together.
    lst_max
        The maximum LST to average together.
    load
        The index of the load to plot (only one load is plotted).
    pol
        The polarization to plot (only one polarization is plotted).
    attribute
        The attribute to actually plot. Can be any attribute of the data object
        that has the same array shape as the primary data array. This includes
        "data", "residuals", "complete_flags", "nsamples".
    offset
        The offset to subtract from the data before plotting. Useful if plotting
        multiple averages on the same axis.

    Returns
    -------
    ax
        The matplotlib Axes on which the plot was made.
    data
        The data object after the (optional) LST selection and the average over
        times, i.e. the object whose attribute was plotted.

    Raises
    ------
    ValueError
        If the requested attribute has no ``shape`` or its shape differs from
        that of ``data.data``.
    """
    if ax is None:
        _fig, ax = plt.subplots(1, 1)

    # Only pay for an LST selection if the requested range is narrower than
    # the full 0-24 range.
    if lst_min > 0 or lst_max < 24:
        data = select_lsts(data, lst_range=(lst_min, lst_max))

    data = average_over_times(data)

    q = getattr(data, attribute)
    if not hasattr(q, "shape") or q.shape != data.data.shape:
        raise ValueError(
            f"Cannot use attribute '{attribute}' as it doesn't "
            "have the same shape as data."
        )

    # Index 0 on the third axis picks the first entry left after averaging
    # over times.
    spectrum = q[load, pol, 0] - offset

    ax.plot(data.freqs, spectrum)
    ax.set_xlabel("Frequency [MHz]")
    ax.set_ylabel("Average Spectrum")

    if logy is None:
        # Decide on the curve that is actually drawn: other loads/pols and the
        # un-offset values are irrelevant to whether a log axis can show it.
        logy = np.all(spectrum > 0)

    if logy:
        ax.set_yscale("log")

    return ax, data


def plot_daily_residuals(
    objs: list[GSData],
    model: mdl.Model | None = None,
    separation: float = 20.0,
    ax: plt.Axes | None = None,
    load: int = 0,
    pol: int = 0,
    **kw,
) -> plt.Axes:
    """
    Make a single plot of residuals for each object.

    Parameters
    ----------
    objs
        A list of objects to plot.
    model
        The model used to compute residuals for any object that does not already
        have them. May be None only if every object already has residuals.
    separation
        The separation between residuals in K (on the plot).
    ax
        The axis to plot on. If None, a new axis is created.
    load
        The index of the load to plot and for which the RMS is computed.
    pol
        The polarization to plot and for which the RMS is computed.

    Other Parameters
    ----------------
    All other parameters are passed through to :func:`plot_time_average`.

    Returns
    -------
    ax
        The matplotlib Axes on which the plot is made.

    Raises
    ------
    ValueError
        If an object has no residuals and no ``model`` was given.
    """
    if ax is None:
        _fig, ax = plt.subplots(1, 1)

    for i, data in enumerate(objs):
        if data.residuals is None:
            if model is None:
                raise ValueError("If data has no model, must provide one!")
            data = add_model(data, model=model)

        # Forward load/pol so the plotted curve is the same selection that the
        # RMS label below describes.
        ax, d = plot_time_average(
            data,
            attribute="residuals",
            offset=separation * i,
            ax=ax,
            load=load,
            pol=pol,
            **kw,
        )

        # RMS of the time-averaged residuals, weighted by the number of samples.
        rms = np.sqrt(
            averaging.weighted_mean(
                data=d.residuals[load, pol, 0] ** 2,
                weights=d.nsamples[load, pol, 0],
            )[0]
        )

        title = data.get_initial_yearday()
        # The curve was shifted down by i * separation, so the label goes at the
        # same height, just to the right of the highest frequency.
        ax.text(
            data.freqs.max() + 5 * apu.MHz,
            -i * separation,
            f"{title} RMS={rms:.2f}",
        )

    return ax