# Source code for edges.filters.filters

"""Functions that identify and flag bad data in various ways."""

import functools
import logging
import warnings
from collections.abc import Callable, Sequence

import deprecation
import hickle
import numpy as np
from astropy import units as un
from astropy.coordinates import AltAz, Angle, SkyCoord
from astropy.time import Time
from attrs import define
from pygsdata import GSData, GSFlag, gsregister
from pygsdata.select import _mask_times, select_freqs

from edges import __version__

from .. import modeling as mdl
from .. import types as tp
from ..averaging import NsamplesStrategy, averaging, get_weights_from_strategy
from ..filters import xrfi as rfi
from .runners import run_xrfi

logger = logging.getLogger(__name__)


def gsdata_filter(multi_data: bool = False):
    """Build a decorator that turns a flag-computing function into a data filter.

    The decorated function is invoked as ``func(data=data, **kwargs)`` and must
    return a :class:`GSFlag` (or, when ``multi_data`` is true, an iterable of
    :class:`GSFlag` objects that is zipped against the input data sets). The
    resulting wrapper adds the returned flags to the data via
    ``data.add_flags`` and logs how much was flagged.

    The wrapper accepts one extra keyword-only argument, ``flag_id``, giving the
    name under which the flags are stored; it defaults to the ``__name__`` of
    the decorated function.

    Parameters
    ----------
    multi_data
        Whether the filter accepts multiple objects at the same time to filter.
        This is *usually* so as to enable more accurate filtering when comparing
        different days for instance, rather than just performing a loop over the
        days and flagging each independently. When true, a single
        :class:`GSData` input is wrapped into a one-element list.

    Returns
    -------
    callable
        A decorator that wraps the filter function. The wrapped function returns
        the flagged :class:`GSData` (or a list of them when ``multi_data``).

    Raises
    ------
    TypeError
        Raised by the wrapper when ``multi_data`` is false and ``data`` is not a
        single :class:`GSData` object.
    """

    def inner(func: Callable) -> Callable:
        # functools.wraps copies ``func.__dict__`` onto the wrapper, so this makes
        # the undecorated function reachable as ``wrapper.func``. Other filters in
        # this module call e.g. ``sky_coord_filter.func(...)`` to reuse the raw
        # flag computation.
        func.func = func  # type: ignore

        @functools.wraps(func)
        def wrapper(
            data: GSData | Sequence[GSData],
            *,
            flag_id: str | None = None,
            **kwargs,
        ) -> GSData | Sequence[GSData]:
            # Check that the input is of the kind this filter expects.
            if multi_data and isinstance(data, GSData):
                data = [data]
            elif not multi_data and not isinstance(data, GSData):
                raise TypeError(
                    f"'{func.__name__}' only accepts single GSData objects as data."
                )

            def per_file_processing(data: GSData, flags: GSFlag):
                # Number of samples already flagged before this filter is applied.
                old = np.sum(data.flagged_nsamples == 0)

                data = data.add_flags(flag_id or func.__name__, flags)

                if np.all(flags.flags):
                    logger.warning(
                        f"{data.name} was fully flagged during {func.__name__} filter"
                    )
                else:
                    tot = np.sum(data.flagged_nsamples == 0)
                    totsz = data.complete_flags.size

                    # Express every reported number as a percentage so that the
                    # values agree with the '%' signs in the message.
                    old_pct = 100 * old / totsz
                    new_pct = 100 * np.sum(flags.flags) / flags.flags.size
                    tot_pct = 100 * tot / totsz
                    added_pct = 100 * (tot - old) / totsz

                    rep = data.get_initial_yearday(hours=True)

                    logger.info(
                        f"'{rep}': "
                        f"{old_pct:.2f} + {new_pct:.2f} → "
                        f"{tot_pct:.2f}% [bold]<+{added_pct:.2f}%>[/] "
                        f"flagged after [blue]{func.__name__}[/]"
                    )

                return data

            this_flag = func(data=data, **kwargs)

            if multi_data:
                data = [
                    per_file_processing(d, out_flg)
                    for d, out_flg in zip(data, this_flag, strict=False)
                ]
            else:
                data = per_file_processing(data, this_flag)

            return data

        return wrapper

    return inner
@gsregister("filter")
@gsdata_filter()
def aux_filter(
    *,
    data: GSData,
    minima: dict[str, float] | None = None,
    maxima: dict[str, float] | None = None,
) -> GSFlag:
    """
    Flag integrations whose auxiliary measurements fall outside given limits.

    Parameters
    ----------
    data
        The data whose ``auxiliary_measurements`` are inspected.
    minima
        Dictionary mapping auxiliary data keys to minimum allowed values.
    maxima
        Dictionary mapping auxiliary data keys to maximum allowed values.

    Returns
    -------
    flags
        A time-axis :class:`GSFlag` that is True for bad integrations.

    Raises
    ------
    ValueError
        If a key of ``minima`` or ``maxima`` is not an auxiliary measurement.
    """
    # All minima are checked first, then all maxima.
    checks = [
        (key, limit, "minimum", lambda values, bound: values < bound)
        for key, limit in (minima or {}).items()
    ]
    checks += [
        (key, limit, "maximum", lambda values, bound: values > bound)
        for key, limit in (maxima or {}).items()
    ]

    bad = np.zeros(data.ntimes, dtype=bool)

    for key, limit, kind, violates in checks:
        if key not in data.auxiliary_measurements.keys():  # noqa: SIM118
            raise ValueError(
                f"{key} not in data.auxiliary_measurements. "
                f"Allowed: {data.auxiliary_measurements.keys()}"
            )

        outside = violates(data.auxiliary_measurements[key], limit)

        # Sometimes, the auxiliary data will be shape (Ntimes, Nloads).
        # In this case, if any load is bad, all should be flagged.
        if outside.ndim == 2:
            outside = np.any(outside, axis=1)

        n_before = np.sum(bad)
        bad |= outside
        n_added = np.sum(bad) - n_before

        if n_added:
            logger.info(
                f"{n_added}/{len(bad) - n_before} times flagged due to {key} {kind}"
            )

    return GSFlag(flags=bad, axes=("time",))
@gsregister("filter")
@gsdata_filter()
def sun_filter(
    *,
    data: GSData,
    elevation_range: tuple[float, float],
) -> GSFlag:
    """
    Flag integrations during which the sun is outside an elevation window.

    Parameters
    ----------
    data
        The data to flag; its ``get_sun_azel`` method provides the elevation.
    elevation_range
        The minimum and maximum allowed sun elevation in degrees

    Returns
    -------
    flags
        A time-axis :class:`GSFlag`, True where the elevation is below the
        minimum or above the maximum.
    """
    sun_el = data.get_sun_azel()[1]
    below_min = sun_el < elevation_range[0]
    above_max = sun_el > elevation_range[1]
    return GSFlag(flags=below_min | above_max, axes=("time",))
@gsregister("filter")
@gsdata_filter()
def moon_filter(
    *,
    data: GSData,
    elevation_range: tuple[float, float],
) -> GSFlag:
    """
    Perform a filter based on moon position.

    Parameters
    ----------
    data
        The data to flag; its ``get_moon_azel`` method provides the elevation.
    elevation_range
        The minimum and maximum allowed moon elevation, in the same units as
        the elevation returned by ``data.get_moon_azel()`` (presumably degrees,
        as for :func:`sun_filter` -- confirm against pygsdata).

    Returns
    -------
    flags
        A time-axis :class:`GSFlag`, True where the moon elevation is below the
        minimum or above the maximum.
    """
    _, el = data.get_moon_azel()
    return GSFlag(
        flags=(el < elevation_range[0]) | (el > elevation_range[1]), axes=("time",)
    )
@gsregister("filter")
@gsdata_filter()
def sky_coord_filter(
    *,
    data: GSData,
    coord: str | SkyCoord,
    elevation_range: tuple[Angle, Angle],
) -> GSFlag:
    """
    Flag integrations during which a sky position is outside an elevation window.

    Parameters
    ----------
    data
        The data to flag. Its telescope location and times define the local
        horizontal frame.
    coord
        The sky coordinate to filter on. A string is resolved to a coordinate
        with ``SkyCoord.from_name``.
    elevation_range
        The minimum and maximum allowed elevation (as an astropy Angle).

    Returns
    -------
    flags
        A time-axis :class:`GSFlag`, True where the altitude of ``coord`` is
        below the minimum or above the maximum.
    """
    target = SkyCoord.from_name(coord) if isinstance(coord, str) else coord

    # Use the times of the first load, assuming that this is the antenna data.
    horizon_frame = AltAz(
        location=data.telescope.location, obstime=Time(data.times[:, 0])
    )
    altitude = target.transform_to(horizon_frame).alt

    out_of_range = (altitude < elevation_range[0]) | (altitude > elevation_range[1])
    return GSFlag(flags=out_of_range, axes=("time",))
@gsregister("filter")
@gsdata_filter()
def galaxy_filter(
    *,
    data: GSData,
    elevation_range: tuple[Angle, Angle] = (-90 * un.deg, 0 * un.deg),
) -> GSFlag:
    """
    Flag integrations based on the elevation of the Galactic center.

    This simply calls the undecorated :func:`sky_coord_filter` with the
    coordinate name ``"Galactic Center"``.

    Parameters
    ----------
    data
        The data to flag.
    elevation_range
        The minimum and maximum allowed elevation (as an astropy Angle).

    Returns
    -------
    flags
        The time-axis :class:`GSFlag` computed by :func:`sky_coord_filter`.
    """
    compute_flags = sky_coord_filter.func
    return compute_flags(
        data=data,
        coord="Galactic Center",
        elevation_range=elevation_range,
    )
@define(frozen=False, slots=False)
class _RFIFilterFactory:
    """Callable that runs one of the xRFI methods as a GSData filter.

    An instance is decorated with :func:`gsdata_filter` below. That decorator
    uses ``functools.wraps``, which copies ``__name__`` and ``__doc__`` from the
    instance, so both are provided as properties here.
    """

    # Name of the xRFI method, forwarded to ``run_xrfi`` as ``method``.
    method: str

    @property
    def __name__(self):
        """Name under which the filter is known, e.g. ``rfi_watershed_filter``."""
        return f"rfi_{self.method}_filter"

    @property
    def __docstring__(self):
        """Docstring of the attribute of the ``rfi`` module named ``method``."""
        return getattr(rfi, self.method).__doc__

    @property
    def __doc__(self):
        # functools.wraps copies ``__doc__`` (not ``__docstring__``), so expose
        # the same text under the standard attribute name. If the lookup raises
        # AttributeError, functools.wraps silently skips the attribute.
        return self.__docstring__

    def __call__(
        self,
        data: GSData,
        *,
        n_threads: int = 1,
        freq_range: tuple[float, float] = (40, 200),
        nsamples_strategy: NsamplesStrategy = NsamplesStrategy.FLAGGED_NSAMPLES,
        **kwargs,
    ):
        """Run the xRFI method on the channels inside ``freq_range``.

        Parameters
        ----------
        data
            The data to flag.
        n_threads
            Forwarded to ``run_xrfi``.
        freq_range
            Inclusive lower and upper frequency bounds in MHz. Channels outside
            this range are not passed to the xRFI method and are never flagged
            by this filter.
        nsamples_strategy
            Strategy used by ``get_weights_from_strategy`` to derive weights.
        **kwargs
            Forwarded to ``run_xrfi``.

        Returns
        -------
        GSFlag
            Flags with the same shape as ``data.complete_flags``.
        """
        fmhz = data.freqs.to_value("MHz")
        mask = (fmhz >= freq_range[0]) & (fmhz <= freq_range[1])

        flags = data.complete_flags
        wgt, _ = get_weights_from_strategy(data, nsamples_strategy)

        out_flags = run_xrfi(
            method=self.method,
            spectrum=data.data[..., mask],
            freqs=data.freqs[mask].to_value("MHz"),
            weights=wgt[..., mask],
            n_threads=n_threads,
            **kwargs,
        )

        # Start from "nothing flagged" and fill in only the selected channels.
        out = np.zeros_like(flags)
        out[..., mask] = out_flags

        return GSFlag(
            flags=out,
            axes=("load", "pol", "time", "freq")[-out.ndim :],
        )


rfi_iterative_filter = gsregister("filter")(
    gsdata_filter()(_RFIFilterFactory("iterative"))
)
rfi_model_filter = rfi_iterative_filter  # Backwards compatibility
rfi_watershed_filter = gsregister("filter")(
    gsdata_filter()(_RFIFilterFactory("watershed"))
)
rfi_iterative_sliding_window = gsregister("filter")(
    gsdata_filter()(_RFIFilterFactory("iterative_sliding_window"))
)
@gsregister("filter")
@gsdata_filter()
def apply_flags(*, data: GSData, flags: tp.PathLike | GSFlag):
    """Apply pre-computed flags, given directly or as a file to load.

    Parameters
    ----------
    data
        The data to which the flags will be added (by the filter decorator).
    flags
        Either a :class:`GSFlag` object, which is used as-is, or a path that is
        read with ``hickle.load``.

    Returns
    -------
    The given :class:`GSFlag`, or whatever ``hickle.load`` returns for the path.
    """
    if isinstance(flags, GSFlag):
        return flags

    return hickle.load(flags)
@gsregister("filter")
@gsdata_filter()
def flag_frequency_ranges(
    *, data: GSData, freq_ranges: list[tuple[float, float]], invert: bool = False
):
    """Flag every channel that lies in one of a set of frequency ranges.

    Parameters
    ----------
    data
        The data to flag.
    freq_ranges
        A list of tuples, each containing the start and end of a frequency range
        to flag in MHz. The start is inclusive and the end exclusive.
    invert
        If True, invert the flagging (i.e. only *keep* the data inside the ranges
        given).

    Returns
    -------
    GSFlag
        Flags along the frequency axis.
    """
    channel_mhz = data.freqs.to_value("MHz")

    # Union of all the requested half-open intervals [start, end).
    inside = np.zeros(data.nfreqs, dtype=bool)
    for start, end in freq_ranges:
        inside |= (channel_mhz >= start) & (channel_mhz < end)

    return GSFlag(
        flags=~inside if invert else inside,
        axes=("freq",),
    )
@gsregister("filter")
@gsdata_filter()
def negative_power_filter(*, data: GSData):
    """Filter out integrations that have *any* negative/zero power.

    These integrations obviously have some weird stuff going on.

    Parameters
    ----------
    data
        The data to flag. ``data.data`` is reduced over axes 0, 1 and 3, so it
        must be four-dimensional with time on axis 2.

    Returns
    -------
    GSFlag
        Time-axis flags, True for any integration in which at least one value
        (in any load, polarization or channel) is less than or equal to zero.
    """
    # Use <= so that exactly-zero power is flagged too, as documented.
    flags = np.any(data.data <= 0, axis=(0, 1, 3))

    return GSFlag(flags=flags, axes=("time",))
def _peak_power_filter(
    *,
    data: GSData,
    threshold: float = 40.0,
    peak_freq_range: tuple[float, float] = (80, 200),
    mean_freq_range: tuple[float, float] | None = None,
):
    """
    Filter out whole integrations that have high power in a given frequency range.

    Parameters
    ----------
    data
        The data to flag. ``data.data`` must have frequency as its last axis.
    threshold
        This is the threshold beyond which the peak power causes the integration
        to be flagged. The units of the threshold are 10*log10(peak_power / mean),
        where the mean is the mean power of spectrum in the mean frequency range
        (omitting non-positive values and power spikes >= peak_power/10).
    peak_freq_range
        The range of frequencies (MHz) over which to search for the peak. The
        lower bound is exclusive and the upper bound inclusive.
    mean_freq_range
        The range of frequencies (MHz) over which to take a mean to compare to
        the peak. By default, the same as the ``peak_freq_range``.

    Returns
    -------
    numpy.ndarray
        Boolean array of shape ``(nloads, npols, ntimes)``; True where the
        peak-to-mean ratio in dB exceeds ``threshold``.

    Raises
    ------
    ValueError
        If either frequency range does not have its first value strictly less
        than its second.
    """
    if peak_freq_range[0] >= peak_freq_range[1]:
        raise ValueError(
            f"The frequency range of the peak must be non-zero, got {peak_freq_range}"
        )
    if mean_freq_range is not None and mean_freq_range[0] >= mean_freq_range[1]:
        raise ValueError(
            "The freq range of the mean must be a tuple with first less than second "
            f"value, got {mean_freq_range}"
        )

    freqs = data.freqs.to_value("MHz")
    # Shape of the result when there is nothing to compare: nothing is flagged.
    unflagged_shape = (data.nloads, data.npols, data.ntimes)

    mask = (freqs > peak_freq_range[0]) & (freqs <= peak_freq_range[1])
    if not np.any(mask):
        return np.zeros(shape=unflagged_shape, dtype=bool)

    spec = data.data[..., mask]
    peak_power = spec.max(axis=-1)

    if mean_freq_range is not None:
        mask = (freqs > mean_freq_range[0]) & (freqs <= mean_freq_range[1])
        if not np.any(mask):
            return np.zeros(shape=unflagged_shape, dtype=bool)

        spec = data.data[..., mask]

    # Compare every channel to the peak of *its own* load/pol/time. Adding a
    # trailing axis to ``peak_power`` broadcasts it along frequency for any
    # number of loads and polarizations.
    weights = ((spec > 0) & (spec < peak_power[..., np.newaxis] / 10)).astype(float)

    mean, _ = averaging.weighted_mean(spec, weights=weights, axis=-1)

    # Peak-to-mean ratio expressed in dB.
    peak_power = 10 * np.log10(peak_power / mean)
    return peak_power > threshold
@gsregister("filter")
@gsdata_filter()
def peak_power_filter(
    *,
    data: GSData,
    threshold: float = 40.0,
    peak_freq_range: tuple[float, float] = (80, 200),
    mean_freq_range: tuple[float, float] | None = None,
):
    """
    Flag whole integrations whose peak power stands far above the mean power.

    This is a thin wrapper that packages the boolean array computed by
    ``_peak_power_filter`` into a :class:`GSFlag`.

    Parameters
    ----------
    data
        The data to flag.
    threshold
        The value of 10*log10(peak_power / mean) above which an integration is
        flagged.
    peak_freq_range
        The range of frequencies over which to search for the peak.
    mean_freq_range
        The range of frequencies over which to take a mean to compare to the peak.
        By default, the same as the ``peak_freq_range``.

    Returns
    -------
    GSFlag
        Flags whose axes are the trailing entries of ``(load, pol, time)``
        matching the dimensionality of the computed array.
    """
    candidate_axes = ("load", "pol", "time")

    bad = _peak_power_filter(
        data=data,
        threshold=threshold,
        peak_freq_range=peak_freq_range,
        mean_freq_range=mean_freq_range,
    )

    return GSFlag(flags=bad, axes=candidate_axes[-bad.ndim :])
@gsregister("filter")
@gsdata_filter()
def peak_orbcomm_filter(
    *,
    data: GSData,
    threshold: float = 40.0,
    mean_freq_range: tuple[float, float] | None = (80, 200),
):
    """
    Flag whole integrations that have high power between (137, 138) MHz.

    This calls ``_peak_power_filter`` with the peak search range fixed to
    ``(137.0, 138.0)`` and packages the result into a :class:`GSFlag`.

    Parameters
    ----------
    data
        The data to flag.
    threshold
        The value of 10*log10(peak_power / mean) above which an integration is
        flagged, where the mean is taken over ``mean_freq_range``.
    mean_freq_range
        The range of frequencies over which to take a mean to compare to the
        peak. Defaults to ``(80, 200)``; if None, the mean is taken over the
        (137, 138) peak range itself.

    Returns
    -------
    GSFlag
        Flags whose axes are the trailing entries of ``(load, pol, time)``
        matching the dimensionality of the computed array.
    """
    candidate_axes = ("load", "pol", "time")

    bad = _peak_power_filter(
        data=data,
        threshold=threshold,
        peak_freq_range=(137.0, 138.0),
        mean_freq_range=mean_freq_range,
    )

    return GSFlag(flags=bad, axes=candidate_axes[-bad.ndim :])
@gsregister("filter")
@gsdata_filter()
def single_channel_spike_filter(
    *,
    data: GSData,
    threshold: float = 200,
    freq_range: tuple[tp.FreqType, tp.FreqType] = (88 * un.MHz, 120 * un.MHz),
):
    """Filter data based on single channel spikes.

    This filter detrends the data using a simple convolution kernel with weights
    [0.5, 0, 0.5], which makes single channel spikes stand out. The entire spectrum
    is flagged if the residual of the original spectrum to the de-trended is larger
    than the threshold.

    Parameters
    ----------
    data
        The data to flag.
    threshold
        Maximum allowed absolute deviation of a channel from the average of its
        two neighbours, in the units of ``data.data``.
    freq_range
        Inclusive lower and upper bounds of the band in which to look for
        spikes.

    Returns
    -------
    GSFlag
        Flags over ``(load, pol, time)``. If fewer than three channels fall in
        ``freq_range`` no channel has two neighbours, so an all-False flag over
        the time axis is returned instead.
    """
    freqs = data.freqs
    mask = (freqs >= freq_range[0]) & (freqs <= freq_range[1])

    # A channel can only be compared to its neighbours when at least three
    # channels are selected; with fewer, the residual array would be empty and
    # taking its maximum would raise.
    if np.count_nonzero(mask) < 3:
        return GSFlag(flags=np.zeros(data.ntimes, dtype=bool), axes=("time",))

    power = data.data[..., mask]

    # Average of the two neighbouring channels, for all interior channels.
    avg = (power[..., 2:] + power[..., :-2]) / 2
    deviation_power = np.abs(power[..., 1:-1] - avg)

    return GSFlag(
        flags=np.max(deviation_power, axis=-1) > threshold,
        axes=("load", "pol", "time"),
    )
@gsregister("filter")
@gsdata_filter()
def maxfm_filter(*, data: GSData, threshold: float = 200):
    """Flag integrations with large single-channel spikes in the FM band.

    Provided as a convenience for comparisons with the legacy code, which had a
    filter of this name. It only forwards to the undecorated
    `single_channel_spike_filter` with the frequency range fixed to 88-120 MHz.

    Parameters
    ----------
    data
        The data to flag.
    threshold
        Forwarded unchanged to `single_channel_spike_filter`.
    """
    fm_band = (88 * un.MHz, 120 * un.MHz)
    spike_flags = single_channel_spike_filter.func
    return spike_flags(data=data, threshold=threshold, freq_range=fm_band)
@gsregister("filter")
@gsdata_filter()
def filter_150mhz(*, data: GSData, threshold: float):
    """Filter data based on power around 150 MHz.

    For every load, polarization and integration, this takes the RMS of the
    power in the band 152.75-154.25 MHz after subtracting the mean of that band,
    and compares it to the mean power in the band 156.25-157.75 MHz (which is
    expected to be cleaner). The integration is flagged when
    ``200 * sqrt(rms) / mean`` exceeds ``threshold``.

    Parameters
    ----------
    data
        The data to flag.
    threshold
        Upper limit on the statistic ``200 * sqrt(rms) / mean``.

    Returns
    -------
    GSFlag
        Flags over ``(load, pol, time)``. If the data do not reach 157 MHz, an
        all-False flag over the time axis is returned.
    """
    if data.freqs.max() < 157 * un.MHz:
        return GSFlag(flags=np.zeros(data.ntimes, dtype=bool), axes=("time",))

    freq_mask = (data.freqs >= 152.75 * un.MHz) & (data.freqs <= 154.25 * un.MHz)
    band = data.data[..., freq_mask]

    # Reduce over the frequency axis only, so that there is one RMS value per
    # load/pol/time rather than a single number pooled over all integrations.
    mean = np.mean(band, axis=-1, keepdims=True)
    rms = np.sqrt(np.mean((band - mean) ** 2, axis=-1))

    freq_mask2 = (data.freqs >= 156.25 * un.MHz) & (data.freqs <= 157.75 * un.MHz)
    av = np.mean(data.data[..., freq_mask2], axis=-1)

    # NOTE(review): a second square root is applied to the RMS here. This is
    # kept unchanged; confirm against the legacy definition of the statistic.
    d = 200.0 * np.sqrt(rms) / av

    return GSFlag(
        flags=d > threshold,
        axes=(
            "load",
            "pol",
            "time",
        ),
    )
@gsregister("filter")
@gsdata_filter()
def power_percent_filter(
    *,
    data: GSData,
    freq_range: tuple[float, float] = (100, 200),
    min_threshold: float = 0,
    max_threshold: float = 3,
):
    """Flag integrations by the percentage of antenna power found in a band.

    The power of the ``"ant"`` load is summed over the channels inside
    ``freq_range`` and divided by its sum over all channels. If that percentage
    is outside the thresholds for a given timestamp, the integration is flagged.

    Note: this is a very bespoke filter. Thresholds that make sense will depend
    on both the ``freq_range`` given, and the frequency range of the data itself.
    In this regard, it is very flexible, but care must be taken to set the
    parameters appropriately.

    Parameters
    ----------
    data : GSData
        The data to be flagged. It must have ``data_unit == "power"``, exactly
        three loads, and a load called ``"ant"``.
    freq_range : tuple[float, float]
        The frequency range of the power to be summed in the numerator, in MHz.
        The lower bound is exclusive and the upper bound inclusive.
    min_threshold : float
        Percentage below which the integration will be flagged.
    max_threshold : float
        Percentage above which the integration will be flagged.

    Raises
    ------
    ValueError
        If the data do not satisfy the requirements listed for ``data``.
    """
    unsupported = (
        data.data_unit != "power" or data.nloads != 3 or "ant" not in data.loads
    )
    if unsupported:
        raise ValueError("Cannot perform power percent filter on non-power data!")

    antenna_power = data.data[data.loads.index("ant")]

    channel_mhz = data.freqs.to_value("MHz")
    in_band = (channel_mhz > freq_range[0]) & (channel_mhz <= freq_range[1])

    if not np.any(in_band):
        return GSFlag(flags=np.zeros(data.ntimes, dtype=bool), axes=("time",))

    band_total = np.sum(antenna_power[..., in_band], axis=-1)
    full_total = np.sum(antenna_power, axis=-1)
    ppercent = 100 * band_total / full_total

    too_low = ppercent < min_threshold
    too_high = ppercent > max_threshold
    return GSFlag(flags=too_low | too_high, axes=("pol", "time"))
@gsregister("filter")
@gsdata_filter()
def rms_filter(
    data: GSData,
    threshold: float,
    freq_range: tuple[tp.FreqType, tp.FreqType] = (0.0 * un.MHz, np.inf * un.MHz),
    nsamples_strategy: NsamplesStrategy = NsamplesStrategy.FLAGGED_NSAMPLES,
    model: mdl.Model | None = None,
) -> GSFlag:
    """Filter integrations based on the rms of the residuals.

    Parameters
    ----------
    data
        The data to be filtered.
    threshold
        The threshold at which to flag integrations.
    freq_range
        The frequency range to use in calculating the RMS.
    nsamples_strategy
        The strategy to use to infer weights for computing RMS.
    model
        A model to be used to fit each integration. Not required if a model
        already exists on the data.

    Returns
    -------
    GSFlag
        Flags over ``(load, pol, time)``, True where the weighted RMS of the
        residuals over frequency exceeds ``threshold``. All False when no data
        fall inside ``freq_range``.

    Raises
    ------
    ValueError
        If the data have no residuals and no ``model`` is given.
    """
    from ..analysis.datamodel import add_model

    data = select_freqs(data, freq_range=freq_range)

    if data.data.size == 0:
        # No data in the given frequency range, so nothing to flag.
        return GSFlag(
            flags=np.zeros(shape=(data.nloads, data.npols, data.ntimes), dtype=bool),
            axes=("load", "pol", "time"),
        )

    if data.residuals is None:
        if model is None:
            raise ValueError("Cannot perform rms_filter without residuals or a model.")

        data = add_model(data=data, model=model, nsamples_strategy=nsamples_strategy)

    w = get_weights_from_strategy(data, nsamples_strategy)[0]

    # Average over frequency (the last axis) explicitly, so that the result has
    # one value per load/pol/time as declared in the returned axes.
    mean_square = averaging.weighted_mean(
        data=data.residuals**2, weights=w, axis=-1
    )[0]
    rms = np.sqrt(mean_square)

    return GSFlag(
        flags=(rms > threshold),
        axes=(
            "load",
            "pol",
            "time",
        ),
    )
@gsregister("filter")
@gsdata_filter()
@deprecation.deprecated(
    deprecated_in="8.1.0",
    removed_in="9.0.0",
    current_version=__version__,
    details="Use the rms_filter function instead",
)
def rmsf_filter(
    *,
    data: GSData,
    threshold: float = 200,
    freq_range: tuple[float, float] = (60, 80),
) -> GSFlag:
    """
    Filter data based on rms calculated between 60 and 80 MHz.

    Note that this function is deprecated in favour of the more general
    :func:`rms_filter`.

    Parameters
    ----------
    data
        The data to flag. Its ``data_unit`` must be ``"uncalibrated_temp"`` or
        ``"temperature"``.
    threshold
        The RMS threshold, forwarded to :func:`rms_filter`.
    freq_range
        The frequency range over which the RMS is computed. Plain numbers are
        interpreted as MHz; quantities with frequency units are also accepted.

    Returns
    -------
    GSFlag
        The flags computed by :func:`rms_filter` using a one-term ``LinLog``
        model with ``beta=-2.5``.

    Raises
    ------
    ValueError
        If ``data.data_unit`` is not one of the supported units.
    """
    warnings.warn(
        "rmsf_filter is deprecated, please use rms_filter instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    if data.data_unit not in ("uncalibrated_temp", "temperature"):
        raise ValueError(
            "Unsupported data_unit for rmsf_filter. "
            "Need temperature or uncalibrated_temp"
        )

    # rms_filter declares its frequency range as quantities, so attach MHz to
    # bare numbers (quantities that already carry units are converted to MHz).
    freq_range_mhz = tuple(un.Quantity(edge, un.MHz) for edge in freq_range)

    return rms_filter.func(
        data=data,
        threshold=threshold,
        freq_range=freq_range_mhz,
        nsamples_strategy=NsamplesStrategy.FLAGGED_NSAMPLES,
        model=mdl.LinLog(n_terms=1, beta=-2.5),
    )
@gsdata_filter()
def rms_rfi_filter(
    data: GSData,
    threshold: float = 3.0,
    nsamples_strategy: NsamplesStrategy = NsamplesStrategy.FLAGGED_NSAMPLES,
) -> GSFlag:
    """Flag individual channels whose residual is large compared to the RMS.

    Parameters
    ----------
    data
        The data to flag; its ``residuals`` are used.
    threshold
        A channel is flagged when its residual exceeds ``threshold`` times the
        weighted RMS of the residuals over the last axis.
    nsamples_strategy
        The strategy used to derive the weights of the RMS.

    Returns
    -------
    GSFlag
        Flags over ``(load, pol, time, freq)``.
    """
    weights = get_weights_from_strategy(data, nsamples_strategy)[0]
    residuals = data.residuals

    mean_square = np.average(np.square(residuals), weights=weights, axis=-1)
    # Restore the reduced axis so that the RMS broadcasts against the residuals.
    rms = np.sqrt(mean_square)[..., None]

    is_outlier = residuals > rms * threshold
    return GSFlag(flags=is_outlier, axes=("load", "pol", "time", "freq"))
@gsregister("filter")
@gsdata_filter()
def explicit_day_filter(
    data: GSData,
    flag_days: list[tuple[int, int] | tuple[int, int, int] | int | Time],
) -> GSFlag:
    """Filter out any data coming from specific days.

    Parameters
    ----------
    data
        The data to flag.
    flag_days
        A list of days to flag. Each entry can be a 2-tuple, 3-tuple, astropy.Time
        or an int. If a 2-tuple, it is interpreted as ``(year, day_of_year)``. If
        a 3-tuple, it is interpreted as ``(year, month, day)``. If an int, it is
        interpreted as a Julian day. The list itself is not modified.

    Returns
    -------
    GSFlag
        Time-axis flags, True for integrations in which the integer part of the
        Julian date of any load matches one of the requested days.

    Raises
    ------
    ValueError
        If a sequence entry does not have length 2 or 3, or if an entry is not a
        sequence, a ``Time`` or an ``int``.
    """
    # Build a new list instead of overwriting the entries of the caller's list.
    julian_days = []
    for day in flag_days:
        # Time must be tested before the generic "has a length" check: Time
        # defines __len__, but calling len() on a scalar Time raises.
        if isinstance(day, Time):
            julian_days.append(int(day.jd))
        elif hasattr(day, "__len__"):
            if len(day) == 2:
                t = Time(f"{day[0]:04}:{day[1]:03}:00:00:00.000", format="yday")
            elif len(day) == 3:
                t = Time(
                    f"{day[0]:04}-{day[1]:02}-{day[2]:02} 00:00:00.000", format="iso"
                )
            else:
                raise ValueError("Day must be a 2-tuple, 3-tuple, Time or an int.")

            julian_days.append(int(t.jd))
        else:
            julian_days.append(day)

    if not all(isinstance(day, int) for day in julian_days):
        raise ValueError("All entries in flag_days must be integers.")

    return GSFlag(
        flags=np.any(np.isin(data.times.jd.astype(int), julian_days), axis=-1),
        axes=("time",),
    )
@gsregister("reduce")
def prune_flagged_integrations(data: GSData, **kwargs) -> GSData:
    """Drop every integration that is flagged in all loads, pols and channels.

    Parameters
    ----------
    data
        The data to prune.
    **kwargs
        Accepted but not used.

    Returns
    -------
    GSData
        The result of ``_mask_times`` called with a mask that is True for the
        integrations that are *not* completely flagged.
    """
    fully_flagged = np.all(data.complete_flags, axis=(0, 1, 3))
    keep = ~fully_flagged
    return _mask_times(data, keep)