# Source code for edges.modeling.xtransforms

"""Module defining x-variable transforms for modelling."""

from abc import ABCMeta, abstractmethod
from functools import cached_property
from typing import Self

import attrs
import numpy as np
import yaml

from ..io.serialization import hickleable


def _transform_yaml_constructor(loader: yaml.SafeLoader, node: yaml.nodes.MappingNode):
    """Rebuild an :class:`XTransform` from a YAML mapping tagged ``!<ClassName>``.

    The mapping is constructed deeply (nested values fully built) and then passed
    as keyword arguments to the transform class looked up from the node's tag.
    """
    kwargs = loader.construct_mapping(node, deep=True)
    class_name = node.tag[1:]  # drop the leading "!" of the tag
    return XTransform.get(class_name)(**kwargs)


def _transform_yaml_representer(dumper: yaml.SafeDumper, tr) -> yaml.nodes.MappingNode:
    """Serialise a transform as a YAML mapping tagged ``!<ClassName>``.

    Only the top-level attrs fields are collected (``recurse=False``), so the
    field values are handed to the dumper as-is.
    """
    tag = "!" + type(tr).__name__
    return dumper.represent_mapping(tag, attrs.asdict(tr, recurse=False))


@hickleable
@attrs.define(frozen=True, kw_only=True, slots=False)
class XTransform(metaclass=ABCMeta):
    """Abstract base class for all coordinate transforms.

    Every subclass gets a YAML constructor registered for the tag
    ``!<ClassName>``. Subclasses not declared with ``is_meta=True`` are also
    added to a registry keyed by their lower-cased class name. The registry
    is queried with :meth:`get`.
    """

    # Registry of transform classes keyed by lower-cased class name. It is a
    # single dict shared by the whole hierarchy (intentional class-level state).
    _models = {}

    def __init_subclass__(cls, is_meta=False, **kwargs):
        """Initialize a subclass and add it to the registered models.

        Parameters
        ----------
        is_meta
            If True, the subclass is not added to the registry. Its YAML tag is
            still registered.
        """
        super().__init_subclass__(**kwargs)
        # No Loader is passed, so PyYAML applies its default registration
        # (Loader/FullLoader/UnsafeLoader in PyYAML >= 5.1, not SafeLoader).
        yaml.add_constructor(f"!{cls.__name__}", _transform_yaml_constructor)
        if not is_meta:
            cls._models[cls.__name__.lower()] = cls

    @abstractmethod
    def transform(self, x: np.ndarray) -> np.ndarray:
        """Transform the coordinates."""

    @classmethod
    def get(cls, model: str) -> type[Self]:
        """Get a registered XTransform subclass by name (case-insensitive).

        Parameters
        ----------
        model
            Class name of the transform, e.g. ``"UnitTransform"`` or
            ``"unittransform"``.

        Raises
        ------
        KeyError
            If no transform with that name has been registered.
        """
        try:
            return cls._models[model.lower()]
        except KeyError:
            raise KeyError(
                f"Unknown XTransform {model!r}; available: {sorted(cls._models)}"
            ) from None

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """Transform the coordinates (alias for :meth:`transform`)."""
        return self.transform(x)

    def __getstate__(self):
        """Get the state for pickling (the attrs fields as a dict)."""
        return attrs.asdict(self)
@hickleable
@attrs.define(frozen=True, kw_only=True, slots=False)
class IdentityTransform(XTransform):
    """No-op transform: coordinates pass through untouched."""

    def transform(self, x: np.ndarray) -> np.ndarray:
        """Return the input unchanged (the same object, not a copy)."""
        return x
@hickleable
@attrs.define(frozen=True, kw_only=True, slots=False)
class ScaleTransform(XTransform):
    """Divide every coordinate by a single constant factor.

    Parameters
    ----------
    scale
        Divisor applied to the coordinates; the output is ``original / scale``.
        Coerced to ``float`` on construction.
    """

    scale: float = attrs.field(converter=float)

    def transform(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` divided by ``scale``."""
        factor = self.scale
        return x / factor
def tuple_converter(x):
    """Turn any iterable into a tuple of floats.

    Each element is passed through :class:`float`, so ints and numeric strings
    are accepted. The number of elements is preserved.
    """
    return tuple(map(float, x))
@hickleable
@attrs.define(frozen=True, kw_only=True, slots=False)
class CentreTransform(XTransform):
    """Shift coordinates so that the mid-point of ``range`` lands on ``centre``.

    Parameters
    ----------
    range
        Two-element range of the input coordinates, converted to a tuple of
        floats. Only its first two entries are used.
    centre
        The value the mid-point of ``range`` is mapped to (default 0).
    """

    range: tuple[float, float] = attrs.field(converter=tuple_converter)
    centre: float = attrs.field(default=0.0, converter=float)

    def transform(self, x: np.ndarray) -> np.ndarray:
        """Return ``x - range[0] - (range[1] - range[0]) / 2 + centre``."""
        lo = self.range[0]
        hi = self.range[1]
        half_width = (hi - lo) / 2
        return x - lo - half_width + self.centre
@hickleable
@attrs.define(frozen=True, kw_only=True, slots=False)
class ShiftTransform(XTransform):
    """Subtract a constant offset from the coordinates.

    Parameters
    ----------
    shift
        Offset subtracted from every coordinate, so the output is
        ``original - shift``. Coerced to ``float``; defaults to 0.
    """

    shift: float = attrs.field(converter=float, default=0.0)

    def transform(self, x: np.ndarray) -> np.ndarray:
        """Return ``x`` minus ``shift``."""
        offset = self.shift
        return x - offset
@hickleable
@attrs.define(frozen=True, kw_only=True, slots=False)
class UnitTransform(XTransform):
    """A transform that takes the input range down to -1 to 1.

    Parameters
    ----------
    range
        The range that is rescaled to (-1, 1). This need not be the range of the
        actual data, so the final coordinates may extend past (-1, 1).

    Raises
    ------
    ValueError
        On construction, if ``range`` does not have exactly two elements or has
        zero width. A zero width would make :meth:`transform` divide by zero.
    """

    range: tuple[float, float] = attrs.field(converter=tuple_converter)

    @range.validator
    def _check_range(self, attribute, value):
        # Runs after the converter, so ``value`` is already a tuple of floats.
        if len(value) != 2:
            raise ValueError(
                f"{attribute.name} must have exactly two elements, got {value!r}"
            )
        if value[0] == value[1]:
            raise ValueError(f"{attribute.name} must have non-zero width, got {value!r}")

    @cached_property
    def _centre(self) -> CentreTransform:
        # Cached helper that moves the mid-point of ``range`` to 0.
        return CentreTransform(centre=0, range=self.range)

    def transform(self, x: np.ndarray) -> np.ndarray:
        """Transform the coordinates (``range[0] -> -1``, ``range[1] -> 1``)."""
        return 2 * self._centre.transform(x) / (self.range[1] - self.range[0])
@hickleable
@attrs.define(frozen=True, kw_only=True, slots=False)
class LogTransform(XTransform):
    """A transform that takes the (natural) logarithm of the input.

    The final coordinates will be ``log(original / scale)``.

    Parameters
    ----------
    scale
        Reference value that the coordinates are divided by before taking the
        logarithm. Converted to ``float`` on construction, like the other
        scalar transform parameters in this module. Defaults to 1.
    """

    scale: float = attrs.field(default=1.0, converter=float)

    def transform(self, x: np.ndarray) -> np.ndarray:
        """Transform the coordinates: ``log(x / scale)``."""
        return np.log(x / self.scale)
@hickleable
@attrs.define(frozen=True, kw_only=True, slots=False)
class Log10Transform(LogTransform):
    """Base-10 counterpart of :class:`LogTransform`.

    Output coordinates are ``log10(original / scale)``. ``scale`` is inherited
    from the parent class.
    """

    def transform(self, x: np.ndarray) -> np.ndarray:
        """Return ``log10(x / scale)``."""
        ratio = x / self.scale
        return np.log10(ratio)
@hickleable
@attrs.define(frozen=True, kw_only=True, slots=False)
class ZerotooneTransform(UnitTransform):
    """Linearly map ``range`` onto the interval (0, 1).

    Works like :class:`UnitTransform`, but ``range[0]`` maps to 0 and
    ``range[1]`` maps to 1. Inputs outside ``range`` land outside (0, 1).
    """

    def transform(self, x: np.ndarray) -> np.ndarray:
        """Return ``(x - range[0]) / (range[1] - range[0])``."""
        start = self.range[0]
        stop = self.range[1]
        return (x - start) / (stop - start)
# Dump every XTransform (and any subclass) as a ``!<ClassName>``-tagged mapping
# when using PyYAML's default Dumper.
yaml.add_multi_representer(XTransform, _transform_yaml_representer, Dumper=yaml.Dumper)