"""Various utility functions."""
from __future__ import annotations
import logging
import numpy as np
import warnings
from collections import defaultdict
from edges_cal import xrfi
from multiprocess import Pool, cpu_count, current_process
from multiprocessing.sharedctypes import RawArray
logger = logging.getLogger(__name__)
_globals = {}
def _init_worker(spectrum, weights, shape):
# This just shoves things into _globals so that each worker in a pool hass access
# to them. If they are in shared memory space (such as a RawArray), then they are
# not copied to each process, just accessed therefrom.
_globals["spectrum"] = spectrum
_globals["weights"] = weights
_globals["shape"] = shape
[docs]
def join_struct_arrays(arrays):
"""Join a list of structured numpy arrays (make new columns)."""
dtype = sum((a.dtype.descr for a in arrays), [])
out = np.empty(len(arrays[0]), dtype=dtype)
for a in arrays:
for name in a.dtype.names:
out[name] = a[name]
return out
[docs]
def run_xrfi(
*,
method: str,
spectrum: np.ndarray,
freq: np.ndarray,
weights: np.ndarray | None = None,
flags: np.ndarray | None = None,
n_threads: int = cpu_count(),
fl_id=None,
**kwargs,
) -> np.ndarray:
"""Run an xrfi method on given spectrum and weights."""
rfi = getattr(xrfi, f"xrfi_{method}")
if weights is None:
if flags is None:
weights = np.ones_like(spectrum)
else:
weights = (~flags).astype(float)
if flags is not None:
weights = np.where(flags, 0, weights)
if spectrum.ndim in rfi.ndim:
flags = rfi(spectrum, weights=weights, **kwargs)[0]
elif spectrum.ndim > max(rfi.ndim) + 1:
# say we have a 3-dimensional spectrum but can only do 1D in the method.
# then we collapse to 2D and recursively run xrfi_pipe. That will trigger
# the *next* clause, which will do parallel mapping over the first axis.
orig_shape = spectrum.shape
new_shape = (-1,) + orig_shape[2:]
flags = run_xrfi(
spectrum=spectrum.reshape(new_shape),
weights=weights.reshape(new_shape),
freq=freq,
method=method,
n_threads=n_threads,
**kwargs,
)
return flags.reshape(orig_shape)
else:
n_threads = min(n_threads, len(spectrum))
# Use a parallel map unless this function itself is being called by a
# parallel map.
wrns = defaultdict(int)
def count_warnings(message, *args, **kwargs):
wrns[str(message)] += 1
old = warnings.showwarning
warnings.showwarning = count_warnings
if current_process().name == "MainProcess" and n_threads > 1:
def fnc(i):
# Gets the spectrum/weights from the global var dict, which was
# initialized by the pool.
# See https://research.wmz.ninja/articles/2018/03/on-sharing-large-
# arrays-when-using-pythons-multiprocessing.html
spec = np.frombuffer(_globals["spectrum"]).reshape(_globals["shape"])[i]
wght = np.frombuffer(_globals["weights"]).reshape(_globals["shape"])[i]
if np.any(wght > 0):
return rfi(spec, freq=freq, weights=wght, **kwargs)[0]
else:
return np.ones_like(spec, dtype=bool)
shared_spectrum = RawArray("d", spectrum.size)
shared_weights = RawArray("d", spectrum.size)
# Wrap X as an numpy array so we can easily manipulates its data.
shared_spectrum_np = np.frombuffer(shared_spectrum).reshape(spectrum.shape)
shared_weights_np = np.frombuffer(shared_weights).reshape(spectrum.shape)
# Copy data to our shared array.
np.copyto(shared_spectrum_np, spectrum)
np.copyto(shared_weights_np, weights)
p = Pool(
n_threads,
initializer=_init_worker,
initargs=(shared_spectrum, shared_weights, spectrum.shape),
)
m = p.map
else:
def fnc(i):
if np.any(weights[i] > 0):
return rfi(spectrum[i], freq=freq, weights=weights[i], **kwargs)[0]
else:
return np.ones_like(spectrum[i], dtype=bool)
m = map
results = m(fnc, range(len(spectrum)))
flags = np.array(list(results))
warnings.showwarning = old
# clear global memory (not sure if it still exists)
_init_worker(0, 0, 0)
fl_id = f"{fl_id}: " if fl_id else ""
if wrns:
for msg, count in wrns.items():
msg = msg.replace("\n", " ")
logger.warning(
f"{fl_id}Received warning '{msg}' {count}/{len(flags)} times."
)
return flags
[docs]
def slice_along_axis(x: np.ndarray, idx: np.ndarray | slice, axis: int = -1):
"""Get a view of x at indices idx on a given axis."""
from_end = False
if axis < 0: # choosing axis at the end
from_end = True
axis = -1 - axis
explicit_inds_slice = axis * (slice(None),)
if from_end:
return x[(Ellipsis, idx) + explicit_inds_slice]
else:
return x[explicit_inds_slice + (idx,)]