Source code for eotransform_xarray.transformers.resample_with_gauss

from abc import abstractmethod
from dataclasses import dataclass, asdict
from typing import Tuple, Union, Literal, Mapping, Any, Optional, Dict, Callable

import numpy as np
import rioxarray  # noqa # pylint: disable=unused-import
import xarray as xr
from affine import Affine
from numpy.typing import NDArray, DTypeLike
from xarray import DataArray, Dataset
from yaml import YAMLObject, SafeLoader

from eotransform_xarray.storage.storage import Storage

try:
    from numba import njit, prange, guvectorize, float64, float32
    from pyresample import SwathDefinition, AreaDefinition
    from pyresample.kd_tree import get_neighbour_info
except ImportError:
    print("ResampleWithGauss requires numba and pyresample.\npip install numba pyresample")
    raise

from eotransform_xarray.transformers import TransformerOfDataArray

EngineType = Literal['dask', 'numba']
EmptyRasterFactory = Callable[[Tuple[int, ...], DTypeLike], NDArray]


[docs]@dataclass class Swath: lons: NDArray lats: NDArray
[docs]@dataclass class Extent: lower_left_x: float lower_left_y: float upper_right_x: float upper_right_y: float def to_tuple(self) -> Tuple[float, float, float, float]: return self.lower_left_x, self.lower_left_y, self.upper_right_x, self.upper_right_y
[docs]@dataclass class Area: name: str projection: str columns: int rows: int extent: Extent transform: Affine description: str = ""
[docs]@dataclass class ProjectionParameter: in_resampling: Dataset out_resampling: Dataset @classmethod def from_storage(cls, storage: Storage) -> "ProjectionParameter": return ProjectionParameter(**{f: v for f, v in storage.load().items()}) def store(self, storage: Storage) -> None: storage.save(asdict(self))
[docs]class MaybePacked: def __init__(self, value: NDArray, is_packed: bool = False): self.value = value self._is_packed = is_packed self._max = value.max() def __or__(self, dtype: DTypeLike) -> "MaybePacked": if self._is_packed: return self if self._max <= np.iinfo(dtype).max: return MaybePacked(self.value.astype(dtype), True) else: return self
[docs]class StorageIntoTheVoid(Storage): def exists(self) -> bool: return False def load(self) -> Mapping[str, Any]: raise NotImplementedError("Can't load from the void.") def save(self, data: Mapping[str, Any]) -> None: pass
[docs]class EngineConfig(YAMLObject): @property @abstractmethod def type(self) -> EngineType: ...
[docs]@dataclass class DaskConfig(EngineConfig): yaml_tag = "!ResampleDaskConfig" yaml_loader = SafeLoader raster_chunk_sizes: Tuple[int, int] @property def type(self) -> EngineType: return "dask"
[docs]@dataclass class NumbaConfig(EngineConfig): yaml_tag = "!ResampleNumbaConfig" yaml_loader = SafeLoader @property def type(self) -> EngineType: return "numba"
[docs]@dataclass class ProcessingConfig(YAMLObject): yaml_tag = "!ResampleProcessing" yaml_loader = SafeLoader num_parameter_calc_procs: int = 1 num_lookup_segments: Optional[int] = None parameter_storage: Optional[Storage] = None resampling_engine: Optional[Union[NumbaConfig, DaskConfig]] = None load_in_resampling_params: Optional[Dict] = None load_out_resampling_params: Optional[Dict] = None
[docs]class ResampleWithGauss(TransformerOfDataArray): class MismatchError(ValueError): ... def __init__(self, swath_src: Swath, area_dst: Area, sigma: float, neighbours: int, lookup_radius: float, processing_config: Optional[ProcessingConfig] = None): self._area_dst = area_dst self._proc_cfg = processing_config or ProcessingConfig() _fill_process_cfg_with_defaults(self._proc_cfg) if self._proc_cfg.parameter_storage.exists(): self._projection_params = ProjectionParameter.from_storage(self._proc_cfg.parameter_storage) else: self._projection_params = self._calc_projection(swath_src, area_dst, neighbours, lookup_radius) self._projection_params.store(self._proc_cfg.parameter_storage) if self._proc_cfg.load_in_resampling_params is not None: self._projection_params.in_resampling = self._projection_params.in_resampling.load( **self._proc_cfg.load_in_resampling_params) if self._proc_cfg.load_out_resampling_params is not None: self._projection_params.out_resampling = self._projection_params.out_resampling.load( **self._proc_cfg.load_out_resampling_params) self._projection_params.out_resampling['weights'] = \ self._distances_to_gauss_weights(self._projection_params.out_resampling['weights'], sigma) def _calc_projection(self, swath: Swath, area: Area, neighbours: int, lookup_radius: float) -> ProjectionParameter: sw_def = SwathDefinition(swath.lons.swapaxes(0, -1), swath.lats.swapaxes(0, -1)) ar_def = AreaDefinition(area.name, area.description, "proj_id", area.projection, area.columns, area.rows, area.extent.to_tuple()) val_in_idc, val_out_idc, idc, distances = get_neighbour_info(sw_def, ar_def, lookup_radius, neighbours, nprocs=self._proc_cfg.num_parameter_calc_procs, segments=self._proc_cfg.num_lookup_segments) packed_idc = MaybePacked(idc) | np.uint8 | np.uint16 | np.uint32 | np.uint64 packed_idc = packed_idc.value.reshape((area.rows, area.columns, -1)) distances = distances.astype(np.float32).reshape((area.rows, area.columns, -1,)) out_mask = val_out_idc[..., np.newaxis].reshape((area.rows, area.columns, -1,)) in_resampling = Dataset({'mask': (('location', 'cell'), val_in_idc[..., np.newaxis])}, coords={'lon': ('location', swath.lons[0]), 'lat': ('location', swath.lats[0])}) \ .chunk({'cell': -1, 'location': -1}) out_resampling = Dataset({'indices': (('y', 'x', 'neighbours'), packed_idc), 'weights': (('y', 'x', 'neighbours'), distances), 'mask': (('y', 'x', 'cell'), out_mask)}) \ .rio.write_crs(area.projection).rio.write_transform(area.transform) if self._proc_cfg.resampling_engine.type == 'numba': out_resampling = out_resampling.chunk({'neighbours': -1, 'cell': -1, 'y': -1, 'x': -1}) else: rc = self._proc_cfg.resampling_engine.raster_chunk_sizes out_resampling = out_resampling.chunk({'neighbours': -1, 'cell': -1, 'y': rc[0], 'x': rc[1]}) return ProjectionParameter(in_resampling, out_resampling) def _distances_to_gauss_weights(self, distances: DataArray, sigma: float) -> DataArray: if self._proc_cfg.resampling_engine.type == 'numba': distances.values = _distance_to_gauss_weight(distances.values, sigma ** 2) return distances else: sig_sqrd = sigma ** 2 return np.exp(-distances ** 2 / sig_sqrd) def __call__(self, x: DataArray) -> DataArray: self._sanity_check_input(x) in_valid = self._projection_params.in_resampling['mask'][:, 0] x = x[..., in_valid.values] indices = self._projection_params.out_resampling['indices'] weights = self._projection_params.out_resampling['weights'] out_valid = self._projection_params.out_resampling['mask'][:, :, 0] if self._proc_cfg.resampling_engine.type == 'numba': resampled = DataArray( _resample_numba(x.values, indices.values, weights.values, out_valid.values), dims=x.dims[:2] + out_valid.dims, coords={**{k: v for k, v in x.coords.items() if k in x.dims[:2]}, **out_valid.coords}) else: resampled = xr.apply_ufunc(_resample_dask, x, indices, weights, out_valid, input_core_dims=[x.dims[-1:], ['neighbours'], ['neighbours'], []], output_dtypes=[x.dtype], dask='parallelized', keep_attrs=True) resampled.attrs = x.attrs return resampled def _sanity_check_input(self, x: DataArray): if self._projection_params.in_resampling['mask'].size != x.shape[-1]: raise ResampleWithGauss.MismatchError("Mismatch between resample transformation projection and input data:" "\nvalid_indices' size doesn't match input data value length:\n" f"{self._projection_params.in_resampling.sizes} != {x.shape}")
def _fill_process_cfg_with_defaults(config: ProcessingConfig): config.parameter_storage = config.parameter_storage or StorageIntoTheVoid() config.resampling_engine = config.resampling_engine or NumbaConfig() @guvectorize([(float32[:], float32, float32[:]), (float64[:], float64, float64[:])], '(),()->()', target='parallel', nopython=True) def _distance_to_gauss_weight(distance, sigma_sqrd, out): out[0] = np.exp(-distance[0] ** 2 / sigma_sqrd) def _resample_dask(in_data: NDArray, indices: NDArray, weights: NDArray, out_valid: NDArray) -> NDArray: in_data = in_data.squeeze((2, 3)) times, parameters = in_data.shape[:2] out = np.empty((times, parameters) + out_valid.shape, dtype=in_data.dtype) _resample_to_single_threaded(in_data, indices, weights, out_valid, out) return out def _resample_numba(in_data: NDArray, indices: NDArray, weights: NDArray, out_valid: NDArray) -> NDArray: times, parameters = in_data.shape[:2] out = np.empty((times, parameters) + out_valid.shape, dtype=in_data.dtype) _resample_to_parallel(in_data, indices, weights, out_valid, out) return out @njit(parallel=False) def _resample_to_single_threaded(in_data: NDArray, indices: NDArray, weights: NDArray, out_valid: NDArray, out: NDArray) -> None: _resample_to_operation(in_data, indices, out, out_valid, weights) @njit(parallel=True) def _resample_to_parallel(in_data: NDArray, indices: NDArray, weights: NDArray, out_valid: NDArray, out: NDArray) -> None: _resample_to_operation(in_data, indices, out, out_valid, weights) @njit(inline='always') def _resample_to_operation(in_data, indices, out, out_valid, weights): times, parameters, in_size = in_data.shape for y in prange(out.shape[-2]): for x in prange(out.shape[-1]): if out_valid[y, x]: for time in range(times): for parameter in range(parameters): weighted_sum = 0 summed_weights = 0 for neighbour in range(indices.shape[-1]): sample_idx = indices[y, x, neighbour] if sample_idx != in_size: sample = in_data[time, parameter, sample_idx] if not np.isnan(sample): w = weights[y, x, neighbour] weighted_sum += w * sample summed_weights += w out[time, parameter, y, x] = weighted_sum / summed_weights if summed_weights > 0 else np.nan else: out[:, :, y, x] = np.nan