"""Plotting utilities."""
from __future__ import annotations
import edges_cal.modelling as mdl
import matplotlib.pyplot as plt
import numpy as np
from astropy import units as apu
from .averaging import averaging
from .averaging.lstbin import lst_bin
from .datamodel import add_model
from .gsdata import GSData
from .gsdata.select import select_lsts
[docs]
def plot_waterfall(
data: GSData,
load: int = 0,
pol: int = 0,
which_flags: tuple[str] = None,
ignore_flags: tuple[str] = (),
ax: plt.Axes | None = None,
cbar: bool = True,
xlab: bool = True,
ylab: bool = True,
title: bool | str = True,
attribute: str = "data",
**imshow_kwargs,
):
"""Plot a waterfall from a GSData object.
Parameters
----------
data
The GSData object to plot.
load
The index of the load to plot (only one load is plotted).
pol
The polarization to plot (only one polarization is plotted).
which_flags
A tuple of flag names to use in order to mask the data. If None, all flags are
used. Send an empty tuple to ignore all flags.
ignore_flags
A tuple of flag names to ignore.
ax
The axis to plot on. If None, a new axis is created.
cbar
Whether to plot a colorbar.
xlab
Whether to plot an x-axis label.
ylab
Whether to plot a y-axis label.
title
Whether to plot a title. If True, the title is the year-day representation
of the dataset. If a string, use that as the title.
attribute
The attribute to actually plot. Can be any attribute of the data object that has
the same array shape as the primary data array. This includes "data",
"residuals", "complete_flags", "nsamples".
"""
q = getattr(data, attribute)
if not hasattr(q, "shape") or q.shape != data.data.shape:
raise ValueError(
f"Cannot use attribute '{attribute}' as it doesn't have "
"the same shape as data."
)
q = np.where(data.get_flagged_nsamples(which_flags, ignore_flags) == 0, np.nan, q)
q = q[load, pol, :, :]
if ax is None:
ax = plt.subplots(1, 1)[1]
if attribute == "resids":
cmap = imshow_kwargs.pop("cmap", "coolwarm")
else:
cmap = imshow_kwargs.pop("cmap", "magma")
times = data.time_array
if data.in_lst:
times = times.hour
times[times < times[0]] += 24
if times.max() > 36:
times -= 24
img = ax.imshow(
q,
origin="lower",
extent=(
data.freq_array.min().to_value("MHz"),
data.freq_array.max().to_value("MHz"),
times.min() if data.in_lst else 0,
(
times.max()
if data.in_lst
else (times.max() - times.min()).to_value("hour")
),
),
cmap=cmap,
aspect="auto",
interpolation="none",
**imshow_kwargs,
)
if xlab:
ax.set_xlabel("Frequency [MHz]")
if ylab:
if data.in_lst:
ax.set_ylabel("LST")
else:
ax.set_ylabel("Hours into Observation")
if title and not isinstance(title, str):
if not data.in_lst:
ax.set_title(
f"{data.get_initial_yearday()}. LST0={data.lst_array[0][0]:.2f}"
)
if title and isinstance(title, str):
ax.set_title(title)
if cbar:
cb = plt.colorbar(img, ax=ax)
cb.set_label(data.loads[load])
return ax
[docs]
def plot_time_average(
data: GSData,
ax: plt.Axes | None = None,
logy=None,
lst_min: float = 0,
lst_max: float = 24,
load: int = 0,
pol: int = 0,
attribute: str = "data",
offset: float = 0.0,
):
"""Make a 1D plot of the time-averaged data.
Parameters
----------
data
The GSData object to plot.
ax
The axis to plot on. If None, a new axis is created.
logy
Whether to plot a logarithmic y-axis. If None, the y-axis is logarithmic if all
the plotted data is positive.
lst_min
The minimum LST to average together.
lst_max
The maximum LST to average together.
load
The index of the load to plot (only one load is plotted).
pol
The polarization to plot (only one polarization is plotted).
attribute
The attribute to actually plot. Can be any attribute of the data object that has
the same array shape as the primary data array. This includes "data",
"residuals", "complete_flags", "nsamples".
offset
The offset to add to the data before plotting. Useful if plotting multiple
averages on the same axis.
"""
if ax is not None:
fig = ax.figure
else:
fig, ax = plt.subplots(1, 1)
if lst_min > 0 or lst_max < 24:
data = select_lsts(data, lst_range=(lst_min, lst_max))
data = lst_bin(data, binsize=24.0)
q = getattr(data, attribute)
if not hasattr(q, "shape") or q.shape != data.data.shape:
raise ValueError(
f"Cannot use attribute '{attribute}' as it doesn't "
"have the same shape as data."
)
ax.plot(data.freq_array, q[load, pol, 0] - offset)
ax.set_xlabel("Frequency [MHz]")
ax.set_ylabel("Average Spectrum")
if logy is None:
logy = np.all(q > 0)
if logy:
ax.set_yscale("log")
return ax, data
[docs]
def plot_daily_residuals(
objs: list[GSData],
model: mdl.Model | None = None,
separation: float = 20.0,
ax: plt.Axes | None = None,
load: int = 0,
pol: int = 0,
**kw,
) -> plt.Axes:
"""
Make a single plot of residuals for each object.
Parameters
----------
objs
A list of objects to plot.
separation
The separation between residuals in K (on the plot).
Other Parameters
----------------
All other parameters are passed through to :func:`plot_time_average`.
Returns
-------
ax
The matplotlib Axes on which the plot is made.
"""
if ax is None:
fig, ax = plt.subplots(1, 1)
for i, data in enumerate(objs):
if data.residuals is None and model is None:
raise ValueError("If data has no model, must provide one!")
if data.residuals is None:
data = add_model(data, model=model)
ax, d = plot_time_average(
data, attribute="residuals", offset=separation * i, ax=ax, **kw
)
rms = np.sqrt(
averaging.weighted_mean(
data=d.resids[load, pol, 0] ** 2, weights=d.nsamples[load, pol, 0]
)[0]
)
title = data.filename.name if data.in_lst else data.get_initial_yearday()
ax.text(
data.freq_array.max() + 5 * apu.MHz,
-i * separation,
f"{title} RMS={rms:.2f}",
)
return ax