# Source code for edges_analysis.gsdata.gsdata

"""
A module containing the class GSData, a variant of UVData specific to single antennas.

The GSData object simplifies handling of radio astronomy data taken from a single
antenna, adding self-consistent metadata along with the data itself, and providing
key methods for data selection, I/O, and analysis.
"""

from __future__ import annotations

import astropy.units as un
import h5py
import hickle
import logging
import numpy as np
import warnings
from astropy.coordinates import EarthLocation, Longitude, UnknownSiteException
from astropy.time import Time
from attrs import converters as cnv
from attrs import define, evolve, field
from attrs import validators as vld
from functools import cached_property
from pathlib import Path
from read_acq.read_acq import ACQError
from typing import Iterable, Literal

from .. import coordinates as crd
from .attrs import npfield, timefield
from .constants import KNOWN_LOCATIONS
from .gsflag import GSFlag
from .history import History, Stamp

logger = logging.getLogger(__name__)


@define(slots=False)
class GSData:
    """A generic container for Global-Signal data.

    Parameters
    ----------
    data
        The data array (i.e. what the telescope measures). This must be a 4D array whose
        dimensions are (load, polarization, time, frequency). The data can be raw
        powers, calibrated temperatures, or even model residuals to such. Their type is
        specified by the ``data_unit`` attribute.
    freq_array
        The frequency array. This must be a 1D array of frequencies specified as an
        astropy Quantity, with one entry per frequency channel of ``data``.
    time_array
        The time array. This must be a 2D array of shape (times, loads). It can be in
        one of two formats: either an astropy Time object, specifying the absolute time,
        or an astropy Longitude object, specifying the LSTs. In "lst" mode, there are
        many methods that become unavailable.
    telescope_location
        The telescope location. This must be an astropy EarthLocation object; any other
        value is unpacked into the ``EarthLocation`` constructor.
    loads
        The names of the loads. Usually there is a single load ("ant"), but arbitrary
        loads may be specified. If not given, a default is chosen when the data has
        exactly one or exactly three loads; otherwise an error is raised.
    nsamples
        An array with the same shape as the data array, specifying the number of samples
        that go into each data point. This is unitless, and can be used with the
        ``effective_integration_time`` attribute to compute the total effective
        integration time going into any measurement. Defaults to an array of ones.
    effective_integration_time
        An astropy Quantity that specifies the amount of time going into a single
        "sample" of the data. Must be convertible to seconds. Defaults to one second.
    flags
        A dictionary mapping filter names (strings) to GSFlag objects. Each GSFlag is
        checked for compatibility with this object.
    history
        A History object recording the previous processing steps.
    telescope_name
        The name of the telescope.
    residuals
        An optional array of the same shape as data that holds the residuals of a model
        fit to the data.
    data_unit
        The kind of quantity held in ``data``. One of "power", "temperature",
        "uncalibrated" or "uncalibrated_temp".
    auxiliary_measurements
        A dictionary mapping measurement names to arrays. Each array must have its
        leading axis be the same length as the time array. Must be empty when the
        times are LSTs.
    time_ranges
        The start and end of each integration, with shape (times, loads, 2). If not
        given, each range starts at the corresponding entry of ``time_array`` and lasts
        for ``effective_integration_time``.
    filename
        The filename from which the data was read (if any). Used for writing additional
        data if more is added (eg. flags, data model).
    name
        A name for the object. Defaults to the empty string.
    """

    # NOTE: the declaration order of the fields matters. The validators and defaults
    # below read attributes (e.g. ``self.data``, ``self.time_array``) that must
    # already have been set when they run.
    data: np.ndarray = npfield(dtype=float, possible_ndims=(4,))
    freq_array: un.Quantity[un.MHz] = npfield(possible_ndims=(1,), unit=un.MHz)
    time_array: Time | Longitude = timefield(possible_ndims=(2,))
    telescope_location: EarthLocation = field(
        validator=vld.instance_of(EarthLocation),
        converter=lambda x: (
            EarthLocation(*x) if not isinstance(x, EarthLocation) else x
        ),
    )
    loads: tuple[str] = field(converter=tuple)
    nsamples: np.ndarray = npfield(dtype=float, possible_ndims=(4,))
    effective_integration_time: un.Quantity[un.s] = field(default=1 * un.s)
    flags: dict[str, GSFlag] = field(factory=dict)
    # The history is excluded from equality comparisons (eq=False).
    history: History = field(
        factory=History, validator=vld.instance_of(History), eq=False
    )
    telescope_name: str = field(default="unknown")
    residuals: np.ndarray | None = npfield(
        default=None, possible_ndims=(4,), dtype=float
    )
    data_unit: Literal["power", "temperature", "uncalibrated", "uncalibrated_temp"] = (
        field(default="power")
    )
    auxiliary_measurements: dict = field(factory=dict)
    time_ranges: Time | Longitude = timefield(shape=(None, None, 2))
    filename: Path | None = field(default=None, converter=cnv.optional(Path))
    _file_appendable: bool = field(default=True, converter=bool)
    name: str = field(default="", converter=str)

    @nsamples.validator
    def _nsamples_validator(self, attribute, value):
        """Require nsamples to have exactly the shape of the data."""
        if value.shape != self.data.shape:
            raise ValueError("nsamples must have the same shape as data")

    @nsamples.default
    def _nsamples_default(self) -> np.ndarray:
        """Default to one sample per data point."""
        return np.ones_like(self.data)

    @flags.validator
    def _flags_validator(self, attribute, value):
        """Require a dict of string keys to GSFlag objects compatible with this object."""
        if not isinstance(value, dict):
            raise TypeError("flags must be a dict")

        for key, flag in value.items():
            if not isinstance(flag, GSFlag):
                raise TypeError("flags values must be GSFlag instances")

            flag._check_compat(self)

            if not isinstance(key, str):
                raise ValueError("flags keys must be strings")

    @residuals.validator
    def _residuals_validator(self, attribute, value):
        """Require residuals, when given, to have exactly the shape of the data."""
        if value is not None and value.shape != self.data.shape:
            raise ValueError("residuals must have the same shape as data")

    @freq_array.validator
    def _freq_array_validator(self, attribute, value):
        """Require one frequency per channel of the data's last axis."""
        if value.shape != (self.nfreqs,):
            raise ValueError(
                "freq_array must have the size nfreqs. "
                f"Got {value.shape} instead of {self.nfreqs}"
            )

    @time_array.validator
    def _time_array_validator(self, attribute, value):
        """Require the time array to have shape (ntimes, nloads)."""
        if value.shape != (self.ntimes, self.nloads):
            raise ValueError(
                f"time_array must have the size (ntimes, nloads), got {value.shape} "
                f"instead of {(self.ntimes, self.nloads)}"
            )

    @time_ranges.default
    def _time_ranges_default(self):
        """Build ranges running from each time to that time plus one integration.

        In LST mode the ranges are built in hours and returned as a Longitude;
        otherwise they are built in Julian days and returned as a Time.
        """
        if self.in_lst:
            return Longitude(
                np.concatenate(
                    (
                        self.time_array.hour[:, :, None],
                        self.time_array.hour[:, :, None]
                        + self.effective_integration_time.to_value("hour"),
                    ),
                    axis=-1,
                )
                * un.hour
            )
        else:
            return Time(
                np.concatenate(
                    (
                        self.time_array.jd[:, :, None],
                        self.time_array.jd[:, :, None]
                        + self.effective_integration_time.to_value("day"),
                    ),
                    axis=-1,
                ),
                format="jd",
            )

    @time_ranges.validator
    def _time_ranges_validator(self, attribute, value):
        """Require shape (ntimes, nloads, 2) and, outside LST mode, end > start."""
        if value.shape != (self.ntimes, self.nloads, 2):
            raise ValueError(
                f"time_ranges must have the size (ntimes, nloads, 2), got {value.shape}"
                f" instead of {(self.ntimes, self.nloads, 2)}."
            )

        # The ordering check is skipped entirely in LST mode.
        if not self.in_lst and not np.all(value[..., 1] - value[..., 0] > 0):
            # TODO: properly check lst-type input, which can wrap...
            raise ValueError("time_ranges must all be greater than zero")

    @loads.default
    def _loads_default(self) -> tuple[str]:
        """Pick default load names from the length of the data's first axis."""
        if self.data.shape[0] == 1:
            return ("ant",)
        elif self.data.shape[0] == 3:
            return ("ant", "internal_load", "internal_load_plus_noise_source")
        else:
            raise ValueError(
                "If data has more than one source, loads must be specified"
            )

    @loads.validator
    def _loads_validator(self, attribute, value):
        """Require one string name per load in the data."""
        if len(value) != self.data.shape[0]:
            raise ValueError(
                "loads must have the same length as the number of loads in data"
            )

        if not all(isinstance(x, str) for x in value):
            raise ValueError("loads must be a tuple of strings")

    @effective_integration_time.validator
    def _effective_integration_time_validator(self, attribute, value):
        """Require a Quantity whose unit is equivalent to seconds."""
        if not isinstance(value, un.Quantity):
            raise TypeError("effective_integration_time must be a Quantity")

        if not value.unit.is_equivalent("s"):
            raise ValueError("effective_integration_time must be in seconds")

    @auxiliary_measurements.validator
    def _aux_meas_vld(self, attribute, value):
        """Require a dict of string keys to arrays whose first axis has length ntimes.

        A non-empty dict is rejected when the time array is a Longitude.
        """
        if not isinstance(value, dict):
            raise TypeError("auxiliary_measurements must be a dictionary")

        if isinstance(self.time_array, Longitude) and value:
            raise ValueError(
                "If times are LSTs, auxiliary_measurements cannot be specified"
            )

        for key, val in value.items():
            if not isinstance(key, str):
                raise TypeError("auxiliary_measurements keys must be strings")
            if not isinstance(val, np.ndarray):
                raise TypeError("auxiliary_measurements values must be arrays")
            if val.shape[0] != self.ntimes:
                raise ValueError(
                    "auxiliary_measurements values must have the size ntimes "
                    f"({self.ntimes}), but for {key} got shape {val.shape}"
                )

    @data_unit.validator
    def _data_unit_validator(self, attribute, value):
        """Require data_unit to be one of the four recognised strings."""
        if value not in (
            "power",
            "temperature",
            "uncalibrated",
            "uncalibrated_temp",
        ):
            raise ValueError(
                'data_unit must be one of "power", "temperature", "uncalibrated",'
                '"uncalibrated_temp"'
            )

    @property
    def nfreqs(self) -> int:
        """The number of frequency channels (length of the last data axis)."""
        return self.data.shape[-1]

    @property
    def nloads(self) -> int:
        """The number of loads (length of the first data axis)."""
        return self.data.shape[0]

    @property
    def ntimes(self) -> int:
        """The number of times (length of the second-to-last data axis)."""
        return self.data.shape[-2]

    @property
    def npols(self) -> int:
        """The number of polarizations (length of the second data axis)."""
        return self.data.shape[1]

    @property
    def model(self) -> np.ndarray | None:
        """The model of the data (data minus residuals), or None without residuals."""
        if self.residuals is None:
            return None
        return self.data - self.residuals

    @property
    def resids(self) -> np.ndarray | None:
        """Deprecated alias of ``residuals``; emits a DeprecationWarning on access."""
        warnings.warn(
            DeprecationWarning("Use the 'residuals' attribute instead of 'resids'")
        )
        return self.residuals
@classmethod
def read_acq(
    cls,
    filename: str | Path,
    telescope_location: str | EarthLocation,
    name="{year}_{day}",
    **kw,
) -> GSData:
    """Read an ACQ file into a new object.

    Parameters
    ----------
    filename
        Path to the ACQ file.
    telescope_location
        Either an EarthLocation, or a string naming a site. A string is looked up
        first with ``EarthLocation.of_site`` and, failing that, in
        ``KNOWN_LOCATIONS``.
    name
        A format string for the name of the object. The fields ``year``, ``day``,
        ``hour`` and ``minute`` (taken from the first entry of the time array) and
        ``stem`` (the stem of the filename) are available.
    **kw
        Passed through to the class constructor.

    Raises
    ------
    ImportError
        If the ``read_acq`` package cannot be imported.
    ACQError
        If the antenna spectra read from the file are empty.
    ValueError
        If ``telescope_location`` is a string that is not a known site.
    """
    filename = Path(filename)

    try:
        from read_acq import read_acq
    except ImportError as e:
        raise ImportError(
            "read_acq is not installed -- install it to read ACQ files"
        ) from e

    _, (pant, pload, plns), anc = read_acq.decode_file(filename, meta=True)

    if pant.size == 0:
        raise ACQError(f"No data in file {filename}")

    # The times are removed from the ancillary data so that only the remaining
    # entries are stored as auxiliary measurements.
    times = Time(anc.data.pop("times"), format="yday", scale="utc")

    if isinstance(telescope_location, str):
        try:
            telescope_location = EarthLocation.of_site(telescope_location)
        except UnknownSiteException:
            try:
                telescope_location = KNOWN_LOCATIONS[telescope_location]
            except KeyError as err:
                # Chain explicitly so the traceback shows the failed lookup as
                # the cause rather than as an error raised while handling another.
                raise ValueError(
                    "telescope_location must be an EarthLocation or a known site, "
                    f"got {telescope_location}"
                ) from err

    year, day, hour, minute = times[0, 0].to_value("yday", "date_hm").split(":")
    name = name.format(
        year=year, day=day, hour=hour, minute=minute, stem=filename.stem
    )

    return cls(
        # Stack the three spectra along a new load axis, then insert a
        # length-one polarization axis after it.
        data=np.array([pant.T, pload.T, plns.T])[:, np.newaxis],
        time_array=times,
        freq_array=anc.frequencies * un.MHz,
        data_unit="power",
        loads=("ant", "internal_load", "internal_load_plus_noise_source"),
        # The loop variable is deliberately not called ``name``, to avoid
        # confusion with the object name computed above.
        auxiliary_measurements={key: anc.data[key] for key in anc.data},
        filename=filename,
        telescope_location=telescope_location,
        name=name,
        **kw,
    )
@classmethod
def from_file(cls, filename: str | Path, **kw) -> GSData:
    """Build an object from a file, choosing the reader from the file extension.

    Files ending in ``.acq`` are read with ``read_acq``, which receives ``**kw``.
    Files ending in ``.gsh5`` are read with ``read_gsh5``, which does not receive
    ``**kw``. Any other extension raises a ``ValueError``.
    """
    path = Path(filename)
    extension = path.suffix

    if extension == ".gsh5":
        return cls.read_gsh5(path)
    if extension == ".acq":
        return cls.read_acq(path, **kw)

    raise ValueError("Unrecognized file type")
@classmethod
def read_gsh5(cls, filename: str | Path) -> GSData:
    """Read a GSH5 file and build a new object from its contents.

    Parameters
    ----------
    filename
        Path to the GSH5 (HDF5) file.

    Notes
    -----
    The stored time array is interpreted as LSTs in hours if every value is below
    24, and as Julian dates otherwise.

    The ``effective_integration_time`` and ``telescope_name`` attributes are
    restored when present in the file. Files lacking them fall back to the class
    defaults.
    """
    # Attributes that may be missing from the file; only those found are passed
    # to the constructor, so that the class defaults apply otherwise.
    optional = {}

    with h5py.File(filename, "r") as fl:
        data = fl["data"][:]
        lat, lon, alt = fl["telescope_location"][:]
        telescope_location = EarthLocation(
            lat=lat * un.deg, lon=lon * un.deg, height=alt * un.m
        )

        times = fl["time_array"][:]
        if np.all(times < 24.0):
            time_array = Longitude(times * un.hour)
        else:
            time_array = Time(times, format="jd", location=telescope_location)

        freq_array = fl["freq_array"][:] * un.MHz
        data_unit = fl.attrs["data_unit"]
        objname = fl.attrs["name"]
        loads = fl.attrs["loads"].split("|")

        if "effective_integration_time" in fl.attrs:
            # Stored as a plain number of seconds.
            optional["effective_integration_time"] = (
                fl.attrs["effective_integration_time"] * un.s
            )
        if "telescope_name" in fl.attrs:
            optional["telescope_name"] = fl.attrs["telescope_name"]

        aux_grp = fl["auxiliary_measurements"]
        auxiliary_measurements = {key: aux_grp[key][:] for key in aux_grp.keys()}
        nsamples = fl["nsamples"][:]

        flg_grp = fl["flags"]
        flags = {}
        if "names" in flg_grp.attrs:
            for flag_name in flg_grp.attrs["names"]:
                flags[flag_name] = hickle.load(flg_grp[flag_name])

        history = History.from_repr(fl.attrs["history"])
        residuals = fl["residuals"][()] if "residuals" in fl else None

    return cls(
        data=data,
        time_array=time_array,
        freq_array=freq_array,
        data_unit=data_unit,
        loads=loads,
        auxiliary_measurements=auxiliary_measurements,
        filename=filename,
        nsamples=nsamples,
        flags=flags,
        history=history,
        telescope_location=telescope_location,
        residuals=residuals,
        name=objname,
        **optional,
    )
def write_gsh5(self, filename: str) -> GSData:
    """Write the object to a GSH5 file and return a copy pointing at that file.

    The file is opened in write mode. The time array is stored as LST hours in LST
    mode, and as Julian dates otherwise. The returned object is this object updated
    with ``filename`` set to the file that was written.
    """
    site = self.telescope_location

    with h5py.File(filename, "w") as fl:
        # Array-valued quantities become datasets.
        fl["data"] = self.data
        fl["freq_array"] = self.freq_array.to_value("MHz")
        fl["time_array"] = self.time_array.hour if self.in_lst else self.time_array.jd
        fl["telescope_location"] = np.array(
            [site.lat.deg, site.lon.deg, site.height.to_value("m")]
        )
        fl["nsamples"] = self.nsamples

        # Each flag gets its own sub-group; the names are recorded on the group.
        flag_group = fl.create_group("flags")
        if self.flags:
            flag_group.attrs["names"] = tuple(self.flags.keys())
            for flag_name, flag in self.flags.items():
                hickle.dump(flag, flag_group.create_group(flag_name))

        # Scalar and string metadata become attributes of the file.
        metadata = {
            "loads": "|".join(self.loads),
            "effective_integration_time": self.effective_integration_time.to_value(
                "s"
            ),
            "telescope_name": self.telescope_name,
            "data_unit": self.data_unit,
            "history": repr(self.history),
            "name": self.name,
        }
        for key, val in metadata.items():
            fl.attrs[key] = val

        # The residuals are only stored when they exist.
        if self.residuals is not None:
            fl["residuals"] = self.residuals

        aux_group = fl.create_group("auxiliary_measurements")
        for meas_name, meas in self.auxiliary_measurements.items():
            aux_group[meas_name] = meas

    return self.update(filename=filename)
def update(self, **kwargs):
    """Return a new object with the given attributes replaced.

    The ``history`` keyword is treated specially: a ``Stamp`` (or a dictionary from
    which a ``Stamp`` is built) is added to the existing history instead of
    replacing it. When ``history`` is omitted or None the existing history is kept.
    Any other value raises a ``ValueError``.
    """
    entry = kwargs.pop("history", None)

    if entry is None:
        new_history = self.history
    elif isinstance(entry, Stamp):
        new_history = self.history.add(entry)
    elif isinstance(entry, dict):
        new_history = self.history.add(Stamp(**entry))
    else:
        raise ValueError("History must be a Stamp object or dictionary")

    return evolve(self, history=new_history, **kwargs)
def __add__(self, other: GSData) -> GSData:
    """Combine two objects as an average weighted by their unflagged nsamples.

    Both objects must have the same data shape, frequencies and times (or LSTs),
    and neither may carry auxiliary measurements.

    In the result, ``nsamples`` is the sum of the two flagged nsamples, and a
    single flag named "summed_flags" marks the points flagged in both inputs. The
    residuals are averaged in the same way when both objects have them, and are
    dropped otherwise. Points with zero total weight get the value zero.

    Raises
    ------
    TypeError
        If ``other`` is not a GSData object.
    ValueError
        If the two objects are not compatible as described above.
    """
    if not isinstance(other, GSData):
        raise TypeError("can only add GSData objects")

    if self.data.shape != other.data.shape:
        raise ValueError("Cannot add GSData objects with different shapes")

    if self.auxiliary_measurements or other.auxiliary_measurements:
        raise ValueError("Cannot add GSData objects with auxiliary measurements")

    if not np.allclose(self.freq_array, other.freq_array):
        raise ValueError("Cannot add GSData objects with different frequencies")

    if self.in_lst != other.in_lst:
        raise ValueError("Cannot add GSData objects with different time formats")

    if self.in_lst:
        if not np.allclose(self.time_array.hour, other.time_array.hour):
            raise ValueError("Cannot add GSData objects with different LSTs")
    else:
        if not np.allclose(self.time_array.jd, other.time_array.jd):
            raise ValueError("Cannot add GSData objects with different times")

    # If none of the above, then we have two GSData objects at the same times and
    # frequencies. Adding them should just be a weighted sum.
    w1 = self.flagged_nsamples
    w2 = other.flagged_nsamples
    nsamples = w1 + w2

    def _weighted_mean(a, b):
        # Plain arrays are used on purpose: the values underneath the mask of a
        # masked-array result are not guaranteed by numpy, so they must not be
        # read back through ``.data``. The weights are already zero wherever a
        # point is flagged, and np.where additionally keeps non-finite values at
        # zero-weight points from leaking into the sum.
        with np.errstate(invalid="ignore"):
            total = np.where(w1 > 0, w1 * a, 0.0) + np.where(w2 > 0, w2 * b, 0.0)
        out = np.zeros_like(total)
        # Only divide where there is any weight; the rest stays at zero.
        np.divide(total, nsamples, out=out, where=nsamples > 0)
        return out

    mean = _weighted_mean(self.data, other.data)

    if self.residuals is not None and other.residuals is not None:
        resids = _weighted_mean(self.residuals, other.residuals)
    else:
        resids = None

    # A point is only flagged in the sum if it is flagged in both inputs.
    total_flags = GSFlag(flags=self.complete_flags & other.complete_flags)

    return self.update(
        data=mean,
        residuals=resids,
        nsamples=nsamples,
        flags={"summed_flags": total_flags},
    )


@cached_property
def lst_array(self) -> Longitude:
    """The local sidereal time of each entry of the time array.

    In LST mode this is the time array itself. Otherwise it is the apparent
    sidereal time at the telescope location.
    """
    if self.in_lst:
        return self.time_array
    else:
        return self.time_array.sidereal_time("apparent", self.telescope_location)


@cached_property
def lst_ranges(self) -> Longitude:
    """The local sidereal time of each entry of the time ranges.

    In LST mode this is ``time_ranges`` itself. Otherwise it is the apparent
    sidereal time at the telescope location.
    """
    if self.in_lst:
        return self.time_ranges
    else:
        return self.time_ranges.sidereal_time("apparent", self.telescope_location)


@cached_property
def gha(self) -> Longitude:
    """The GHA's of the observations, computed from the LSTs."""
    return Longitude(crd.lst2gha(self.lst_array.hour) * un.hourangle)
def get_moon_azel(self) -> tuple[np.ndarray, np.ndarray]:
    """Get the Moon's azimuth and elevation for each time in deg.

    The times of the "ant" load are used. Raises a ``ValueError`` in LST mode.
    """
    if not self.in_lst:
        ant_index = self.loads.index("ant")
        ant_times = self.time_array[:, ant_index]
        return crd.moon_azel(ant_times, self.telescope_location)

    raise ValueError(
        "Cannot compute Moon positions when time array is not a Time object"
    )
def get_sun_azel(self) -> tuple[np.ndarray, np.ndarray]:
    """Get the Sun's azimuth and elevation for each time in deg.

    The times of the "ant" load are used. Raises a ``ValueError`` in LST mode.
    """
    if not self.in_lst:
        ant_index = self.loads.index("ant")
        ant_times = self.time_array[:, ant_index]
        return crd.sun_azel(ant_times, self.telescope_location)

    raise ValueError(
        "Cannot compute Sun positions when time array is not a Time object"
    )
def to_lsts(self) -> GSData:
    """
    Return a version of the object whose time array holds LSTs.

    Warning: this is an irreversible operation. You cannot go back to UTC after
    doing this. Furthermore, the auxiliary measurements will be lost.

    An object that is already in LST mode is returned unchanged.
    """
    if not self.in_lst:
        return self.update(time_array=self.lst_array, auxiliary_measurements={})

    return self
@property
def in_lst(self) -> bool:
    """Whether the time array is a Longitude, i.e. holds LSTs."""
    times = self.time_array
    return isinstance(times, Longitude)


@property
def nflagging_ops(self) -> int:
    """The number of entries in the flags dictionary."""
    all_flags = self.flags
    return len(all_flags)
def get_cumulative_flags(
    self, which_flags: tuple[str] | None = None, ignore_flags: tuple[str] = ()
) -> np.ndarray:
    """Return the logical OR of a selection of the flags.

    Parameters
    ----------
    which_flags
        The names of the flags to combine. None selects every flag on the object.
    ignore_flags
        Names to leave out of the selection.

    Returns
    -------
    flags
        A boolean array. It is all False, with the shape of the data, when
        nothing is selected. Otherwise it is the OR of the selected flags,
        expanded against the data shape if its shape differs.
    """
    full_shape = self.data.shape

    if which_flags is None:
        candidates = self.flags.keys()
    elif which_flags and self.flags:
        candidates = which_flags
    else:
        return np.zeros(full_shape, dtype=bool)

    selected = tuple(name for name in candidates if name not in ignore_flags)
    if not selected:
        return np.zeros(full_shape, dtype=bool)

    first, *others = selected
    combined = self.flags[first].full_rank_flags
    for name in others:
        combined = combined | self.flags[name].full_rank_flags

    if combined.shape == full_shape:
        return combined

    # OR-ing with an all-False array broadcasts the result against the data shape.
    return combined | np.zeros(full_shape, dtype=bool)
@cached_property
def complete_flags(self) -> np.ndarray:
    """The cumulative flags over every flag on the object (cached on first access)."""
    everything = self.get_cumulative_flags()
    return everything
def get_flagged_nsamples(
    self, which_flags: tuple[str] | None = None, ignore_flags: tuple[str] = ()
) -> np.ndarray:
    """Return nsamples with every flagged point set to zero.

    ``which_flags`` and ``ignore_flags`` select the flags that are applied, and
    are passed on to ``get_cumulative_flags``.
    """
    flagged = self.get_cumulative_flags(which_flags, ignore_flags)
    # One where the point is kept, zero where it is flagged.
    keep = (~flagged).astype(int)
    return self.nsamples * keep
@cached_property
def flagged_nsamples(self) -> np.ndarray:
    """The nsamples after applying every flag on the object (cached on first access)."""
    weights = self.get_flagged_nsamples()
    return weights
def get_initial_yearday(self, hours: bool = False, minutes: bool = False) -> str:
    """Return the year-day representation of the first time of the "ant" load.

    Parameters
    ----------
    hours
        Whether to include the hour in the returned string.
    minutes
        Whether to also include the minute. Requires ``hours``.

    Raises
    ------
    ValueError
        If minutes are requested without hours, or if the object is in LST mode.
    """
    if minutes and not hours:
        raise ValueError("Cannot return minutes without hours")

    if self.in_lst:
        raise ValueError(
            "Cannot represent times as year-days, as the object is in LST mode"
        )

    subfmt = "date_hm" if hours else "date"
    first_time = self.time_array[0, self.loads.index("ant")]
    text = first_time.to_value("yday", subfmt)

    if hours and not minutes:
        # Drop the final colon-separated field, which holds the minutes.
        return text.rpartition(":")[0]
    return text
def add_flags(
    self,
    filt: str,
    flags: np.ndarray | GSFlag | Path,
    append_to_file: bool | None = None,
) -> GSData:
    """Append a set of flags to the object and optionally append them to file.

    You can always write out a *new* file, but appending flags is non-destructive,
    and so we allow it to be appended, in order to save disk space and I/O.

    Parameters
    ----------
    filt
        The name under which to store the flags. Must not already be in use.
    flags
        The flags to add. An array is wrapped in a GSFlag with axes
        (load, pol, time, freq); a string or Path is read with
        ``GSFlag.from_file``.
    append_to_file
        Whether to also write the new flags into the object's file. By default
        this is done whenever the object has a filename and is marked as
        appendable.

    Returns
    -------
    new
        A new object that includes the added flags. This object is not modified.

    Raises
    ------
    ValueError
        If flags named ``filt`` already exist, or if appending is requested but
        the object has no filename or is not appendable.
    """
    if isinstance(flags, np.ndarray):
        flags = GSFlag(flags=flags, axes=("load", "pol", "time", "freq"))
    elif isinstance(flags, (str, Path)):
        flags = GSFlag.from_file(flags)

    flags._check_compat(self)

    if filt in self.flags:
        raise ValueError(f"Flags for filter '{filt}' already exist")

    new = self.update(flags={**self.flags, filt: flags})

    can_append = new.filename is not None and new._file_appendable
    if append_to_file is None:
        append_to_file = can_append

    if append_to_file and not can_append:
        raise ValueError(
            "Cannot append to file without a filename specified on the object!"
        )

    if not append_to_file:
        return new

    with h5py.File(new.filename, "a") as fl:
        try:
            # Only the shapes are compared, so no data-sized array needs to be
            # allocated. A mismatch raises ValueError, just as multiplying the
            # arrays would.
            np.broadcast_shapes(fl["data"].shape, np.shape(flags.full_rank_flags))
        except ValueError:
            # Can't append to file because it would be inconsistent.
            return new

        flg_grp = fl["flags"]
        names_in_file = flg_grp.attrs["names"] if "names" in flg_grp.attrs else ()

        # Write every flag on the new object that the file does not list yet.
        for flag_name, flag in new.flags.items():
            if flag_name not in names_in_file:
                hickle.dump(flag, flg_grp.create_group(flag_name))

        flg_grp.attrs["names"] = tuple(new.flags.keys())

    return new
def remove_flags(self, filt: str) -> GSData:
    """Return a new object without the flags stored under the name ``filt``.

    Raises a ``ValueError`` if there are no flags with that name.
    """
    if filt in self.flags:
        remaining = {**self.flags}
        del remaining[filt]
        return self.update(flags=remaining)

    raise ValueError(f"No flags for filter '{filt}'")
def time_iter(self) -> Iterable[tuple[slice, slice, slice]]:
    """Iterate over index tuples that each select one time from a data-shape array.

    Each tuple has a full slice for the load, polarization and frequency axes, and
    an integer for the time axis.
    """
    everything = slice(None)
    yield from (
        (everything, everything, index, everything) for index in range(self.ntimes)
    )
def load_iter(self) -> Iterable[tuple[int]]:
    """Iterate over index tuples that each select one load from a data-shape array.

    Each tuple holds a single integer, indexing the leading (load) axis.
    """
    yield from ((index,) for index in range(self.nloads))
def freq_iter(self) -> Iterable[tuple[slice, slice, slice]]:
    """Iterate over index tuples that each select one channel from a data-shape array.

    Each tuple has a full slice for the load, polarization and time axes, and an
    integer for the frequency axis.
    """
    everything = slice(None)
    yield from (
        (everything, everything, everything, index) for index in range(self.nfreqs)
    )