Source code for eotransform_xarray.transformers.masking_where

from typing import Union, Callable, Any

import numpy as np
from xarray import DataArray, Dataset

from eotransform_xarray.transformers import TransformerOfXArrayData, XArrayData

MaskingSource = Union[Callable[[Any], Any], DataArray, Dataset]


class MaskWhere(TransformerOfXArrayData):
    """Transformer that masks xarray data by delegating to ``x.where``.

    The configured predicate is passed to ``where`` as the condition and
    ``replacement_value`` as the fill value.  Following xarray's ``where``
    semantics, entries are kept where the condition holds and replaced
    elsewhere.  When ``invert`` is set, the negated predicate is used as the
    condition instead.
    """

    def __init__(self, predicate: MaskingSource, replacement_value: Any, invert=False):
        """Store the masking configuration.

        :param predicate: callable, ``DataArray`` or ``Dataset`` used as the
            condition argument of ``where``.
        :param replacement_value: value handed to ``where`` as the fill value.
        :param invert: when truthy, the predicate is negated before use.
        """
        self._invert = invert
        self._replacement_value = replacement_value
        self._predicate = predicate

    def __call__(self, x: XArrayData) -> XArrayData:
        """Return ``x.where(condition, replacement_value)`` for the stored setup."""
        # Pick the condition once so there is a single ``where`` call site.
        condition = _not(self._predicate) if self._invert else self._predicate
        return x.where(condition, self._replacement_value)
def _not(predicate: MaskingSource) -> MaskingSource:
    """Return the logical negation of a masking source.

    :param predicate: either a callable producing a mask from its argument,
        or an array-like object exposing ``astype`` (e.g. ``DataArray`` or
        ``Dataset``) that already is a mask.
    :return: for a callable, a new callable returning the negated mask; for
        an array-like object, the negated boolean mask itself.
    """
    if callable(predicate):
        def negated(v: Any) -> Any:
            # Negate logically rather than bitwise: ``~`` on an integer 0/1
            # mask yields -1/-2, which are both truthy, so the inversion
            # would silently select everything.  ``np.logical_not`` agrees
            # with ``~`` for boolean masks and matches the boolean cast
            # applied to non-callable predicates below.
            return np.logical_not(predicate(v))

        return negated
    # Cast to boolean first so that ``~`` acts as a logical, not bitwise, NOT.
    return ~(predicate.astype(np.bool_))