Source code for eotransform_xarray.transformers.random_sample

from typing import Tuple, Sequence

import numpy as np
from eotransform.protocol.transformer import Transformer
from more_itertools import repeatfunc

from eotransform_xarray.transformers import XArrayData


[docs]class RandomSample(Transformer[XArrayData, Sequence[XArrayData]]): def __init__(self, n, ignore_extent=None): self._n = n self._ignore_extent_name = ignore_extent def __call__(self, x: XArrayData) -> Sequence[XArrayData]: patch_size = x.attrs['patch_size'] extent = x.sizes['y'], x.sizes['x'] has_extent_to_ignore = self._ignore_extent_name is not None if has_extent_to_ignore: ignore_extent = x.attrs['extents'][self._ignore_extent_name] valid_points = _find_valid_indices(patch_size, extent, ignore_extent) def _random_indices(): if has_extent_to_ignore: i, j = valid_points[:, np.random.randint(valid_points.shape[1])] else: i = np.random.randint(0, extent[0] - patch_size) j = np.random.randint(0, extent[1] - patch_size) return i, j return [_sample_patch_from(x, position, patch_size) for position in repeatfunc(_random_indices, self._n)]
def _find_valid_indices(patch_size, extent, ignore_extent): ii, jj = np.meshgrid(range(0, extent[0] - patch_size), range(0, extent[1] - patch_size), indexing='ij') mask = np.zeros_like(ii, dtype=bool) left, top, right, bottom = ignore_extent left = max(left - patch_size, 0) top = max(top - patch_size, 0) right = min(right, mask.shape[0]) bottom = min(bottom, mask.shape[1]) mask[left:right, top:bottom] = True ii = np.ma.masked_array(ii, mask=mask).compressed() jj = np.ma.masked_array(jj, mask=mask).compressed() return np.stack([ii, jj]) def _sample_patch_from(x: XArrayData, position: Tuple[int, int], patch_size): top, left = position bottom = top + patch_size right = left + patch_size patch = x.isel(y=slice(top, bottom), x=slice(left, right)) patch.attrs = patch.attrs.copy() patch.attrs['sampled_extent'] = (top, left, bottom, right) return patch