from typing import Sequence
from eotransform.protocol.transformer import Transformer
from eotransform_xarray.functional.load_file_dataframe_to_array import CONCATED_ATTRS_KEY
from eotransform_xarray.transformers import XArrayData
BATCHED_ATTRS_KEY = 'batched_attrs'
[docs]class Batch(Transformer[XArrayData, Sequence[XArrayData]]):
def __init__(self, size: int, batch_dim: str):
self._size = size
self._batch_dim = batch_dim
def __call__(self, x: XArrayData) -> Sequence[XArrayData]:
n_samples = x.sizes[self._batch_dim]
batches = [x.isel({self._batch_dim: slice(i, min(i + self._size, n_samples))})
for i in range(0, n_samples, self._size)]
if CONCATED_ATTRS_KEY in x.attrs:
self._split_concated_attributes_to_batches(x, batches, n_samples)
return batches
def _split_concated_attributes_to_batches(self, x, batches, n_samples):
concated_attrs = x.attrs[CONCATED_ATTRS_KEY]
n_concated = len(concated_attrs)
assert n_concated == n_samples, f"length of concatenated attributes {n_concated}," \
f" must match length of input array {n_samples}"
split_attrs = [concated_attrs[i:min(i + self._size, n_concated)] for i in range(0, n_concated, self._size)]
for batch, attr in zip(batches, split_attrs):
batch.attrs = batch.attrs.copy()
batch.attrs[BATCHED_ATTRS_KEY] = attr
del batch.attrs[CONCATED_ATTRS_KEY]