Source code for edges_analysis.filters.filters

"""Functions that identify and flag bad data in various ways."""

from __future__ import annotations

import functools
import hickle
import logging
import numpy as np
import yaml
from astropy import units as u
from attrs import define
from edges_cal import modelling as mdl
from edges_cal import types as tp
from edges_cal import xrfi as rfi
from edges_cal.xrfi import ModelFilterInfoContainer, model_filter
from pathlib import Path
from typing import Callable, Sequence

from .. import tools
from ..averaging import averaging
from ..data import DATA_PATH
from ..datamodel import add_model
from ..gsdata import GSData, GSFlag, gsregister

logger = logging.getLogger(__name__)


class _GSDataFilter:
    """Callable wrapper that runs a flagging function and attaches its flags.

    The wrapped ``func`` is called as ``func(data=data, **kwargs)`` and must
    return a :class:`GSFlag` (or, when ``multi_data`` is True, one
    :class:`GSFlag` per input object). The returned flags are added to the data
    with ``GSData.add_flags`` and a summary is logged.

    Parameters
    ----------
    func
        The flagging function to wrap.
    multi_data
        Whether ``func`` operates on a sequence of GSData objects at once.
    """

    def __init__(
        self,
        func: Callable,
        multi_data: bool = False,
    ):
        self.func = func
        self.multi_data = multi_data
        # ``updated=()`` copies the name/doc metadata of ``func`` onto this
        # instance without trying to merge ``func.__dict__`` into ours.
        functools.update_wrapper(self, func, updated=())

    @staticmethod
    def _describe(data: GSData) -> str:
        """Return a short label for ``data`` to use in log messages."""
        if not data.in_lst:
            return data.get_initial_yearday(hours=True)
        return data.filename or "unknown"

    def _attach_flags(
        self,
        data: GSData,
        flags: GSFlag,
        write: bool | None,
        flag_id: str | None,
    ) -> GSData:
        """Add ``flags`` to ``data`` and log how much is now flagged."""
        n_before = np.sum(data.flagged_nsamples == 0)

        data = data.add_flags(
            flag_id or self.func.__name__, flags, append_to_file=write
        )

        label = self._describe(data)
        if np.all(flags.flags):
            logger.warning(
                f"{label} was fully flagged during {self.func.__name__} filter"
            )
            return data

        n_new = np.sum(flags.flags)
        n_after = np.sum(data.flagged_nsamples == 0)
        total = data.complete_flags.size

        # Every number below is a percentage: the first, third and fourth are
        # relative to the full data size, the second to the size of the new
        # flag array (which may cover fewer axes than the data).
        logger.info(
            f"'{label}': "
            f"{100 * n_before / total:.2f} + "
            f"{100 * n_new / flags.flags.size:.2f} → "
            f"{100 * n_after / total:.2f}% "
            f"[bold]<+{100 * (n_after - n_before) / total:.2f}%>[/] "
            f"flagged after [blue]{self.func.__name__}[/]"
        )
        return data

    def __call__(
        self,
        data: Sequence[tp.PathLike | GSData],
        *,
        write: bool | None = None,
        flag_id: str | None = None,
        **kwargs,
    ) -> GSData | Sequence[GSData]:
        """Run the wrapped filter on ``data`` and return the flagged data.

        Parameters
        ----------
        data
            A GSData object or a path to one. For multi-data filters, a
            sequence of such entries (a single entry is wrapped in a list).
        write
            Passed to ``GSData.add_flags`` as ``append_to_file``.
        flag_id
            Name under which the flags are stored; defaults to the name of the
            wrapped function.
        **kwargs
            Passed through to the wrapped function.

        Returns
        -------
        GSData or list of GSData
            The input object(s) with the new flags added.

        Raises
        ------
        TypeError
            If a single-data filter is given anything but one GSData object
            (or a path to one).
        """
        if isinstance(data, (Path, str)):
            data = GSData.from_file(data)

        if self.multi_data:
            if isinstance(data, GSData):
                data = [data]
            else:
                # Read any entries that are still paths, so that the filter
                # always receives actual GSData objects.
                data = [
                    d if isinstance(d, GSData) else GSData.from_file(d) for d in data
                ]
        elif not isinstance(data, GSData):
            raise TypeError(
                f"'{self.func.__name__}' only accepts single GSData objects as data."
            )

        this_flag = self.func(data=data, **kwargs)

        if self.multi_data:
            return [
                self._attach_flags(d, flg, write, flag_id)
                for d, flg in zip(data, this_flag)
            ]
        return self._attach_flags(data, this_flag, write, flag_id)


@define
class gsdata_filter:  # noqa: N801
    """A decorator to register a filtering function as a potential filter.

    Any function that is wrapped by :func:`gsdata_filter` is called with the data
    as a keyword argument, and so must implement the following signature::

        def fnc(
            *,
            data: GSData | Sequence[GSData],
            **kwargs
        ) -> GSFlag | Sequence[GSFlag]

    Where the ``data`` is either a single GSData object, or sequence of such objects.
    The return value should be a :class:`GSFlag` object, which contains the flags.
    When ``multi_data`` is True the function must return one :class:`GSFlag` per
    input object, in the same order.

    Parameters
    ----------
    multi_data
        Whether the filter accepts multiple objects at the same time to filter. This
        is *usually* so as to enable more accurate filtering when comparing different
        days for instance, rather than just performing a loop over the days and
        flagging each independently.
    """

    multi_data: bool = False

    def __call__(self, func: Callable) -> Callable:
        """Wrap the function in a GSDataFilter instance."""
        return _GSDataFilter(func, self.multi_data)
def chunked_iterative_model_filter(
    *,
    x: np.ndarray,
    data: np.ndarray,
    flags: np.ndarray | None = None,
    init_flags: np.ndarray | None = None,
    chunk_size: float = np.inf,
    **kwargs,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, ModelFilterInfoContainer]:
    """Perform a chunk-wise iterative model filter.

    This breaks the given data into smaller chunks and then calls
    :func:`edges_cal.xrfi.model_filter` on each chunk, returning the full 1D array of
    flags after all the chunks have been processed.

    Parameters
    ----------
    x
        The coordinates of the data.
    data
        The data to filter. NaN entries are always flagged.
    flags
        Flags that already apply to the data. Defaults to no flags.
    init_flags
        Initial flags passed to each ``model_filter`` call. Defaults to no flags.
    chunk_size
        The size of the chunks to process, in units of the input coordinates, ``x``.
        Must be positive.
    **kwargs
        Everything else is passed to :func:`edges_cal.xrfi.model_filter`.

    Returns
    -------
    flags
        The 1D array of flags corresponding to the data. Note that input flags are
        not modified in the course of this function, but the output does already
        contain those flags.
    resid
        Residuals to the model
    std
        Estimates of the standard deviation of the data at each data point.
    infos
        Container accumulating the info object returned by each chunk's
        ``model_filter`` call.

    Raises
    ------
    ValueError
        If ``chunk_size`` is not positive (which would never advance the chunk
        window).
    """
    # Written as a negated comparison so that NaN is rejected as well.
    if not chunk_size > 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    if flags is None:
        flags = np.zeros(len(x), dtype=bool)

    if init_flags is None:
        init_flags = np.zeros(len(x), dtype=bool)

    # The ``|`` creates a new array, so the caller's ``flags`` is left untouched.
    out_flags = flags | np.isnan(data)
    resids = np.zeros_like(data)
    std = np.zeros_like(data)

    x_max = x.max()
    xmin = x.min()
    infos = ModelFilterInfoContainer()

    # NOTE(review): the loop ends once ``xmin`` reaches ``x.max()``, so a sample
    # lying exactly on a chunk boundary at ``x.max()`` is never passed to
    # ``model_filter`` and keeps its input flag and zero residual/std.
    while xmin < x_max:
        mask = (x >= xmin) & (x < xmin + chunk_size)

        # A gap in ``x`` wider than a chunk leaves nothing to fit here.
        if np.any(mask):
            out_flags[mask], info = model_filter(
                x=x[mask],
                data=data[mask],
                flags=out_flags[mask],
                init_flags=init_flags[mask],
                **kwargs,
            )
            resids[mask] = info.get_residual()
            std[mask] = info.stds[-1]
            infos = infos.append(info)

        xmin += chunk_size

    return out_flags, resids, std, infos
def _as_time_key(entry):
    """Return ``entry`` as a tuple, wrapping scalars into a 1-tuple."""
    try:
        return tuple(entry)
    except TypeError:
        return (entry,)


def explicit_filter(times, bad, ret_times=False):
    """
    Explicitly filter out certain times.

    Parameters
    ----------
    times : array-like
        The input times. This can be either a recarray, a list of tuples, a list
        of ints, or a 2D array of ints. The columns of the recarray (or the entries
        of the tuples) should correspond to `year`, `day` and `hour`. The last two
        are not required, eg. 2-tuples will be interpreted as ``(year, day)``, and a
        list of ints will be interpreted as just years.
    bad : str, Path or array-like
        Like `times`, but specifying the bad entries. Need not have the same columns
        as `times`. If any bad exists within a given time frame, it will be
        considered bad. Likewise, if bad has higher scope than times, then it will
        also be bad. Eg.: ``times = [2018], bad=[(2018, 125)]``, times will be
        considered bad. Also, ``times=[(2018, 125)], bad=[2018]``, times will be
        considered bad. If a str or Path, reads the bad times from the ``bad_days``
        entry of a YAML file.
    ret_times : bool, optional
        If True, return the good times as well as the indices of such in original
        array.

    Returns
    -------
    keep : list of bool
        One entry per input time, True where the time is not bad.
    times :
        Only if `ret_times=True`. The times that are good: an array if `times` is
        a numpy array, otherwise a list.

    Raises
    ------
    ValueError
        If the entries of `times` or `bad` do not have 1, 2 or 3 components.
    """
    if isinstance(bad, (str, Path)):
        # NOTE(review): ``yaml.load`` with ``FullLoader`` can construct more than
        # plain data; only point this at trusted files (or move to
        # ``yaml.safe_load`` if the files never rely on Python-specific tags).
        with open(bad) as fl:
            bad = yaml.load(fl, Loader=yaml.FullLoader)["bad_days"]

    # Normalise every entry to a tuple so that scalars, lists, array rows and
    # records all hash and compare the same way.
    time_keys = [_as_time_key(t) for t in times]
    bad_keys = [_as_time_key(b) for b in bad]

    if not time_keys:
        keep = []
    elif not bad_keys:
        keep = [True] * len(time_keys)
    else:
        # As before, the number of columns is taken from the first entry.
        nt = len(time_keys[0])
        nb = len(bad_keys[0])

        if nt not in {1, 2, 3}:
            raise ValueError("times must be an array of 1,2 or 3-tuples")
        if nb not in {1, 2, 3}:
            raise ValueError("bad must be an array of 1,2 or 3-tuples")

        # Compare on the columns that both inputs specify.
        ncommon = min(nt, nb)
        bad_set = {b[:ncommon] for b in bad_keys}
        keep = [t[:ncommon] not in bad_set for t in time_keys]

    if not ret_times:
        return keep

    if isinstance(times, np.ndarray):
        good = times[np.asarray(keep, dtype=bool)]
    else:
        good = [t for t, ok in zip(times, keep) if ok]
    return keep, good
@gsregister("filter")
@gsdata_filter()
def aux_filter(
    *,
    data: GSData,
    minima: dict[str, float] | None = None,
    maxima: dict[str, float] | None = None,
) -> GSFlag:
    """
    Perform an auxiliary filter on the object.

    Parameters
    ----------
    data
        The data whose ``auxiliary_measurements`` are checked.
    minima
        Dictionary mapping auxiliary data keys to minimum allowed values.
    maxima
        Dictionary mapping auxiliary data keys to maximum allowed values.

    Returns
    -------
    flags
        Flags along the time axis, True for entries that are bad.

    Raises
    ------
    ValueError
        If a key of ``minima`` or ``maxima`` is not an auxiliary measurement.
    """
    aux = data.auxiliary_measurements
    flags = np.zeros(len(data.time_array), dtype=bool)

    # Minima are checked first, then maxima, each in dictionary order.
    bounds = [(key, limit, "minimum") for key, limit in (minima or {}).items()]
    bounds += [(key, limit, "maximum") for key, limit in (maxima or {}).items()]

    for key, limit, kind in bounds:
        if key not in aux:
            raise ValueError(
                f"{key} not in data.auxiliary_measurements. "
                f"Allowed: {data.auxiliary_measurements.keys()}"
            )

        condition = aux[key] < limit if kind == "minimum" else aux[key] > limit

        # Sometimes, the auxiliary data will be shape (Ntimes, Nloads)
        # In this case, if any load is bad, all should be flagged.
        if condition.ndim == 2:
            condition = np.any(condition, axis=1)

        nflags = np.sum(flags)
        flags |= condition

        if nnew := np.sum(flags) - nflags:
            logger.info(
                f"{nnew}/{len(flags) - nflags} times flagged due to {key} {kind}"
            )

    return GSFlag(flags=flags, axes=("time",))


@gsregister("filter")
@gsdata_filter()
def sun_filter(
    *,
    data: GSData,
    elevation_range: tuple[float, float],
) -> GSFlag:
    """
    Perform a filter based on sun position.

    Parameters
    ----------
    data
        The data to flag.
    elevation_range
        The minimum and maximum allowed sun elevation in degrees

    Returns
    -------
    flags
        Flags along the time axis, True where the sun elevation is outside the
        allowed range.
    """
    _, el = data.get_sun_azel()
    too_low = el < elevation_range[0]
    too_high = el > elevation_range[1]
    return GSFlag(flags=too_low | too_high, axes=("time",))


@gsregister("filter")
@gsdata_filter()
def moon_filter(
    *,
    data: GSData,
    elevation_range: tuple[float, float],
) -> GSFlag:
    """
    Perform a filter based on moon position.

    Parameters
    ----------
    data
        The data to flag.
    elevation_range
        The minimum and maximum allowed moon elevation, in the units returned by
        ``data.get_moon_azel()``.

    Returns
    -------
    flags
        Flags along the time axis, True where the moon elevation is outside the
        allowed range.
    """
    _, el = data.get_moon_azel()
    too_low = el < elevation_range[0]
    too_high = el > elevation_range[1]
    return GSFlag(flags=too_low | too_high, axes=("time",))


@define
class _RFIFilterFactory:
    """Callable that runs one named xRFI method on the data.

    Parameters
    ----------
    method
        Name of the method; it is passed to ``tools.run_xrfi`` and is also the
        attribute of ``edges_cal.xrfi`` whose docstring is exposed.
    """

    method: str

    @property
    def __name__(self):
        return f"rfi_{self.method}_filter"

    # NOTE(review): ``functools.update_wrapper`` copies ``__doc__``, not
    # ``__docstring__``, so this text presumably does not end up on the
    # registered filter -- confirm whether that is intended.
    @property
    def __docstring__(self):
        return getattr(rfi, self.method).__doc__

    def __call__(
        self,
        data: GSData,
        *,
        n_threads: int = 1,
        freq_range: tuple[float, float] = (40, 200),
        **kwargs,
    ):
        """Run the xRFI method on the channels inside ``freq_range`` (in MHz).

        Channels outside the range are never flagged by this filter. All other
        keyword arguments are passed to ``tools.run_xrfi``.
        """
        fmhz = data.freq_array.to_value("MHz")
        mask = (fmhz >= freq_range[0]) & (fmhz <= freq_range[1])

        flags = data.complete_flags

        out_flags = tools.run_xrfi(
            method=self.method,
            spectrum=data.data[..., mask],
            freq=fmhz[mask],
            flags=flags[..., mask],
            weights=data.nsamples[..., mask],
            n_threads=n_threads,
            **kwargs,
        )

        out = np.zeros_like(flags)
        out[..., mask] = out_flags

        return GSFlag(
            flags=out,
            axes=("load", "pol", "time", "freq")[-out.ndim :],
        )


rfi_model_filter = gsregister("filter")(gsdata_filter()(_RFIFilterFactory("model")))
rfi_model_sweep_filter = gsregister("filter")(
    gsdata_filter()(_RFIFilterFactory("model_sweep"))
)
rfi_watershed_filter = gsregister("filter")(
    gsdata_filter()(_RFIFilterFactory("watershed"))
)
rfi_model_nonlinear_window_filter = gsregister("filter")(
    gsdata_filter()(_RFIFilterFactory("model_nonlinear_window"))
)


@gsregister("filter")
@gsdata_filter()
def apply_flags(*, data: GSData, flags: tp.PathLike | GSFlag):
    """Apply flags from a file.

    Parameters
    ----------
    data
        The data to which the flags are applied.
    flags
        Either the flags themselves, or a path from which they are loaded with
        ``hickle.load``.
    """
    # NOTE(review): the loaded object is returned as-is; presumably the file
    # holds a GSFlag -- confirm against how these files are written.
    if not isinstance(flags, GSFlag):
        flags = hickle.load(flags)
    return flags
@gsregister("filter")
@gsdata_filter()
def rfi_explicit_filter(*, data: GSData, file: tp.PathLike | None = None):
    """A filter of explicit channels of RFI.

    Parameters
    ----------
    data
        The data to flag.
    file
        File listing the RFI channels, passed to ``xrfi_explicit``. Defaults to
        ``known_rfi_channels.yaml`` in the package data directory.
    """
    if file is None:
        file = DATA_PATH / "known_rfi_channels.yaml"

    return GSFlag(
        flags=rfi.xrfi_explicit(
            data.freq_array,
            rfi_file=file,
        ),
        axes=("freq",),
    )


@gsregister("filter")
@gsdata_filter()
def flag_frequency_ranges(
    *, data: GSData, freq_ranges: list[tuple[float, float]], invert: bool = False
):
    """Flag frequency ranges.

    Parameters
    ----------
    data
        The data to flag.
    freq_ranges
        Pairs of ``(fmin, fmax)`` in MHz. A channel is inside a range if
        ``fmin <= f < fmax``.
    invert
        If False, flag the channels inside any of the ranges. If True, flag the
        channels outside all of the ranges.
    """
    fmhz = data.freq_array.to_value("MHz")

    in_any_range = np.zeros(data.nfreqs, dtype=bool)
    for fmin, fmax in freq_ranges:
        # Both bounds must hold at once; OR-ing them separately would select
        # every channel.
        in_any_range |= (fmhz >= fmin) & (fmhz < fmax)

    return GSFlag(
        flags=~in_any_range if invert else in_any_range,
        axes=("freq",),
    )


@gsregister("filter")
@gsdata_filter()
def negative_power_filter(*, data: GSData):
    """Filter out integrations that have *any* negative/zero power.

    These integrations obviously have some weird stuff going on.
    """
    flags = np.array([np.any(data.data[slc] <= 0) for slc in data.time_iter()])
    return GSFlag(flags=flags, axes=("time",))


def _peak_power_filter(
    *,
    data: GSData,
    threshold: float = 40.0,
    peak_freq_range: tuple[float, float] = (80, 200),
    mean_freq_range: tuple[float, float] | None = None,
):
    """
    Filters out whole integrations that have high power > 80 MHz.

    Parameters
    ----------
    threshold
        This is the threshold beyond which the peak power causes the integration to be
        flagged. The units of the threshold are 10*log10(peak_power / mean), where the
        mean is the mean power of spectrum in the same frequency range (omitting
        power spikes > peak_power/10)
    peak_freq_range
        The range of frequencies over which to search for the peak.
    mean_freq_range
        The range of frequencies over which to take a mean to compare to the peak.
        By default, the same as the ``peak_freq_range``.

    Returns
    -------
    flags
        Boolean array with the frequency axis of ``data.data`` removed, or an
        all-False array of length ``data.ntimes`` if either frequency range
        contains no channels.

    Raises
    ------
    ValueError
        If either frequency range does not have its first value below its second.
    """
    if peak_freq_range[0] >= peak_freq_range[1]:
        raise ValueError(
            f"The frequency range of the peak must be non-zero, got {peak_freq_range}"
        )
    if mean_freq_range is not None and mean_freq_range[0] >= mean_freq_range[1]:
        raise ValueError(
            "The freq range of the mean must be a tuple with first less than second "
            f"value, got {mean_freq_range}"
        )

    freqs = data.freq_array.to_value("MHz")

    peak_mask = (freqs > peak_freq_range[0]) & (freqs <= peak_freq_range[1])
    if not np.any(peak_mask):
        return np.zeros(data.ntimes, dtype=bool)

    spec = data.data[..., peak_mask]
    peak_power = spec.max(axis=-1)

    if mean_freq_range is not None:
        mean_mask = (freqs > mean_freq_range[0]) & (freqs <= mean_freq_range[1])
        if not np.any(mean_mask):
            return np.zeros(data.ntimes, dtype=bool)
        spec = data.data[..., mean_mask]

    # Only positive channels well below the peak contribute to the mean. The
    # trailing axis lets the per-integration peak broadcast over frequency for
    # any number of loads and polarisations.
    weights = ((spec > 0) & (spec < peak_power[..., np.newaxis] / 10)).astype(float)
    mean, _ = averaging.weighted_mean(spec, weights=weights, axis=-1)

    return 10 * np.log10(peak_power / mean) > threshold


@gsregister("filter")
@gsdata_filter()
def peak_power_filter(
    *,
    data: GSData,
    threshold: float = 40.0,
    peak_freq_range: tuple[float, float] = (80, 200),
    mean_freq_range: tuple[float, float] | None = None,
):
    """
    Filters out whole integrations that have high power > 80 MHz.

    Parameters
    ----------
    threshold
        This is the threshold beyond which the peak power causes the integration to be
        flagged. The units of the threshold are 10*log10(peak_power / mean), where the
        mean is the mean power of spectrum in the same frequency range (omitting
        power spikes > peak_power/10)
    peak_freq_range
        The range of frequencies over which to search for the peak.
    mean_freq_range
        The range of frequencies over which to take a mean to compare to the peak.
        By default, the same as the ``peak_freq_range``.
    """
    flags = _peak_power_filter(
        data=data,
        threshold=threshold,
        peak_freq_range=peak_freq_range,
        mean_freq_range=mean_freq_range,
    )
    # The helper returns a time-only array when a frequency range is empty, so
    # the axes are matched to the dimensionality actually returned.
    return GSFlag(
        flags=flags,
        axes=(
            "load",
            "pol",
            "time",
        )[-flags.ndim :],
    )


@gsregister("filter")
@gsdata_filter()
def peak_orbcomm_filter(
    *,
    data: GSData,
    threshold: float = 40.0,
    mean_freq_range: tuple[float, float] | None = (80, 200),
):
    """
    Filters out whole integrations that have high power between (137, 138) MHz.

    Parameters
    ----------
    threshold
        This is the threshold beyond which the peak power causes the integration to be
        flagged. The units of the threshold are 10*log10(peak_power / mean), where the
        mean is the mean power of spectrum in the ``mean_freq_range`` (omitting
        power spikes > peak_power/10)
    mean_freq_range
        The range of frequencies over which to take a mean to compare to the peak.
        Defaults to (80, 200) MHz. If None, the mean is taken over the peak range,
        (137, 138) MHz.
    """
    flags = _peak_power_filter(
        data=data,
        threshold=threshold,
        peak_freq_range=(137.0, 138.0),
        mean_freq_range=mean_freq_range,
    )
    return GSFlag(
        flags=flags,
        axes=(
            "load",
            "pol",
            "time",
        )[-flags.ndim :],
    )


@gsregister("filter")
@gsdata_filter()
def maxfm_filter(*, data: GSData, threshold: float = 200):
    """Max FM power filter.

    This takes power of the spectrum between 88 MHz and 120 MHz (the fm range).
    In that range, it compares each frequency bin to the mean of its two
    neighbouring bins, and takes the largest absolute deviation over the range.
    If that maximum is greater than the threshold given, the integration will be
    flagged.
    """
    freqs = data.freq_array.to_value("MHz")
    # freq mask between 88 and 120 MHz for the FM range
    fm_freq = (freqs >= 88) & (freqs <= 120)

    # Each tested bin needs a neighbour on either side, so fewer than three
    # channels in the range leaves nothing to compare.
    if np.count_nonzero(fm_freq) < 3:
        return GSFlag(flags=np.zeros(data.ntimes, dtype=bool), axes=("time",))

    fm_power = data.data[..., fm_freq]
    neighbour_mean = (fm_power[..., 2:] + fm_power[..., :-2]) / 2
    deviation = np.abs(fm_power[..., 1:-1] - neighbour_mean)
    maxfm = np.max(deviation, axis=-1)

    return GSFlag(
        flags=maxfm > threshold,
        axes=("load", "pol", "time")[-maxfm.ndim :],
    )


@gsregister("filter")
@gsdata_filter()
def rmsf_filter(
    *,
    data: GSData,
    threshold: float = 200,
    freq_range: tuple[float, float] = (60, 80),
    tload: float = 1000,
    tcal: float = 300,
):
    """
    Rmsf filter - filters out based on rms calculated between 60 and 80 MHz.

    A power law with index -2.5, normalised at 75 MHz, is scaled to each spectrum
    within ``freq_range`` by linear least squares. The rms of the residuals to
    that scaled power law is compared to the threshold.

    Parameters
    ----------
    threshold
        Integrations whose rms exceeds this value are flagged.
    freq_range
        The frequency range, in MHz, over which the rms is computed.
    tload, tcal
        Used to convert data with unit "uncalibrated" via
        ``data * tload + tcal`` before fitting.

    Raises
    ------
    ValueError
        If the data unit is not one of "uncalibrated", "uncalibrated_temp" or
        "temperature".
    """
    freqs = data.freq_array.to_value("MHz")
    freq_mask = (freqs >= freq_range[0]) & (freqs <= freq_range[1])

    if not np.any(freq_mask):
        return GSFlag(flags=np.zeros(data.ntimes, dtype=bool), axes=("time",))

    if data.data_unit == "uncalibrated":
        spec = (data.data * tload) + tcal
    elif data.data_unit in ("uncalibrated_temp", "temperature"):
        spec = data.data
    else:
        raise ValueError(
            "Unsupported data_unit for rmsf_filter. "
            "Need uncalibrated or uncalibrated_temp"
        )

    # Use the MHz values, so the 75 MHz normalisation holds whatever unit the
    # frequency array is stored in.
    init_model = (freqs[freq_mask] / 75.0) ** -2.5
    spec = spec[..., freq_mask]

    # Least-squares amplitude of the power law for each integration.
    T75 = np.sum(init_model * spec, axis=-1) / np.sum(init_model**2)
    prod = T75[..., np.newaxis] * init_model

    rms = np.sqrt(np.mean((spec - prod) ** 2, axis=-1))

    return GSFlag(
        flags=rms > threshold,
        axes=(
            "load",
            "pol",
            "time",
        ),
    )


@gsregister("filter")
@gsdata_filter()
def filter_150mhz(*, data: GSData, threshold: float):
    """Filter based on power around 150 MHz.

    This takes the RMS of the power around 153.5 MHz (in a 1.5 MHz bin), after
    subtracting the mean, and the mean power of a 1.5 MHz bin around 157 MHz
    (which is expected to be cleaner). An integration is flagged if
    ``200 * sqrt(rms) / mean`` is greater than the threshold given.

    If the data does not reach 157 MHz, nothing is flagged.
    """
    if data.freq_array.max() < 157 * u.MHz:
        return GSFlag(flags=np.zeros(data.ntimes, dtype=bool), axes=("time",))

    freq_mask = (data.freq_array >= 152.75 * u.MHz) & (
        data.freq_array <= 154.25 * u.MHz
    )
    band = data.data[..., freq_mask]
    # ``keepdims`` lets the per-integration mean broadcast back over frequency,
    # and the RMS is taken per integration rather than over the whole array.
    mean = np.mean(band, axis=-1, keepdims=True)
    rms = np.sqrt(np.mean((band - mean) ** 2, axis=-1))

    freq_mask2 = (data.freq_array >= 156.25 * u.MHz) & (
        data.freq_array <= 157.75 * u.MHz
    )
    av = np.mean(data.data[..., freq_mask2], axis=-1)

    # NOTE(review): the square root is applied to a quantity that is already an
    # RMS; kept unchanged so existing thresholds keep their meaning -- confirm
    # that this is intended.
    d = 200.0 * np.sqrt(rms) / av

    return GSFlag(
        flags=d > threshold,
        axes=(
            "load",
            "pol",
            "time",
        ),
    )


@gsregister("filter")
@gsdata_filter()
def power_percent_filter(
    *,
    data: GSData,
    freq_range: tuple[float, float] = (100, 200),
    min_threshold: float = -0.7,
    max_threshold: float = 3,
):
    """Filter for the power above 100 MHz seen in swpos 0.

    Calculates the percentage of the "ant" load's power that lies within
    ``freq_range`` (by default between 100 and 200 MHz), and flags integrations
    if the percentage is above or below the given thresholds.

    Raises
    ------
    ValueError
        If the data is not three-load power data containing an "ant" load.
    """
    if data.data_unit != "power" or data.nloads != 3 or "ant" not in data.loads:
        raise ValueError("Cannot perform power percent filter on non-power data!")

    p0 = data.data[data.loads.index("ant")]

    freqs = data.freq_array.to_value("MHz")
    mask = (freqs > freq_range[0]) & (freqs <= freq_range[1])

    if not np.any(mask):
        return GSFlag(flags=np.zeros(data.ntimes, dtype=bool), axes=("time",))

    ppercent = 100 * np.sum(p0[..., mask], axis=-1) / np.sum(p0, axis=-1)
    return GSFlag(
        flags=(ppercent < min_threshold) | (ppercent > max_threshold),
        axes=(
            "pol",
            "time",
        ),
    )


@gsregister("filter")
@gsdata_filter()
def object_rms_filter(
    data: GSData,
    rms_threshold: float,
    f_low: float = 0.0,
    f_high: float = np.inf,
    weighted: bool = False,
    flagged: bool = True,
    model: mdl.Model | None = None,
) -> GSFlag:
    """Filter integrations based on the rms of the residuals.

    Parameters
    ----------
    data
        The data to flag. Must contain at most one time.
    rms_threshold
        The object is flagged if the rms of its residuals exceeds this value.
    f_low, f_high
        Channels outside ``f_low <= f < f_high`` (in MHz) are flagged before the
        rms is computed.
    weighted
        If True, compute a mean of the squared residuals weighted by the flagged
        number of samples.
    flagged
        Only used if ``weighted`` is False. If True, compute the rms only over
        samples that are not flagged.
    model
        Model used to compute residuals if the data has none.

    Raises
    ------
    ValueError
        If the data has more than one time, or has no residuals and no model is
        given.
    """
    if data.ntimes > 1:
        raise ValueError(
            "The object_rms_filter is meant to be performed on lst-averaged data"
        )

    # Flag everything outside the requested band, without writing to file.
    data = flag_frequency_ranges(
        data=data, freq_ranges=[(f_low, f_high)], invert=True, write=False
    )

    if data.residuals is None:
        if model is None:
            raise ValueError(
                "Cannot perform object rms filter without residuals or a model."
            )

        data = add_model(data=data, model=model)

    if weighted:
        mean_square = averaging.weighted_mean(
            data=data.residuals**2, weights=data.flagged_nsamples
        )[0]
    elif flagged:
        unflagged = ~(data.flagged_nsamples == 0)
        mean_square = np.mean(data.residuals[unflagged] ** 2)
    else:
        mean_square = np.mean(data.residuals**2)
    rms = np.sqrt(mean_square)

    logger.info(f"RMS for {data.name}: {rms:.2f} mK")

    return GSFlag(
        flags=np.array([rms > rms_threshold]),
        axes=("time",),
    )