Source code for segram.datastruct.collections

"""Enhanced :mod:`collections.abc` classes implementing
generic data filtering and transformation methods.
"""
from typing import Self, Any, Callable, Literal
from typing import Iterable, Iterator, Sequence, Mapping
from types import MethodType
from abc import abstractmethod
from functools import total_ordering
from itertools import groupby, product, islice
from more_itertools import unique_everseen


[docs] class DataABC(Iterable): """Abstract base class for data classes.""" __slots__ = () @abstractmethod def __iter__(self) -> Iterable: pass
[docs] def pipe( self, func: str | Callable[[Self, ...], Any], *args: Any, **kwds: Any ) -> Any: """Pipe self to a function.""" func = self._handle_string_func(func) return func(self, *args, **kwds)
# Internals --------------------------------------------------------------- @staticmethod def _handle_string_func(func: str | Callable) -> Callable: if isinstance(func, str): def _func(o, *args, **kwds): return getattr(o, func)(*args, **kwds) return _func return func @staticmethod def _get_keyfunc(func: str | Callable, *args: Any, **kwds: Any) -> Callable: def keyfunc(obj): nonlocal func key = func if isinstance(key, str): key = getattr(obj, key) if isinstance(key, MethodType): key = key(*args, **kwds) elif isinstance(key, Callable): key = key(obj, *args, **kwds) return key return keyfunc
class DataIterable(DataABC):
    """Abstract base class for data iterables.

    Adds chainable transformation helpers. Methods returning ``Self``
    rebuild their result with ``self.__class__(iterable)``, so concrete
    subclasses must be constructible from a single iterable.
    """
    # pylint: disable=abstract-method

    # Properties --------------------------------------------------------------

    @property
    def flat(self) -> Self:
        """Copy of the data with one level of nesting removed
        (see :meth:`iter_flat`).
        """
        return self.__class__(self.iter_flat())

    @property
    def list(self) -> "DataList":
        """Data converted to :class:`DataList`."""
        return self.pipe(DataList)

    @property
    def tuple(self) -> "DataTuple":
        """Data converted to :class:`DataTuple`."""
        return self.pipe(DataTuple)

    # Methods -----------------------------------------------------------------

    def get(self, attr: str) -> Self:
        """Extract attributes from data items."""
        return self.map(lambda m: getattr(m, attr))

    def any(self) -> bool:
        """Check whether any item is truthy."""
        return any(self)

    def all(self) -> bool:
        """Check whether all items are truthy."""
        return all(self)

    def map(self, func: str | Callable[[Any, ...], Any], *args: Any, **kwds: Any) -> Self:
        """Map data iterator.

        Parameters
        ----------
        func
            Callable invoked as ``func(item, *args, **kwds)``
            or a name of a method defined on items.
        *args, **kwds
            Passed to the function.
        """
        func = self._handle_string_func(func)
        return self.__class__(func(x, *args, **kwds) for x in self)

    def filter(
        self,
        func: str | Callable[[Any, ...], bool] | None,
        *args: Any,
        **kwds: Any
    ) -> Self:
        """Filter data iterator.

        Parameters
        ----------
        func
            Predicate invoked as ``func(item, *args, **kwds)``
            or a name of a method defined on items.
            ``None`` means filtering by truthiness (``bool``).
        *args, **kwds
            Passed to the predicate.
        """
        if func is None:
            func = bool
        else:
            func = self._handle_string_func(func)
        return self.__class__(x for x in self if func(x, *args, **kwds))

    def unique(self, key: str | Callable[[Any, ...], Any] | None = None) -> Self:
        """Return unique values (only first unique occurrences are returned).

        Parameters
        ----------
        key
            Optional function computing the value used for comparing items.
            A string is interpreted as a name of an attribute or method
            defined on items.
        """
        if isinstance(key, str):
            # ``unique_everseen`` calls ``key(item)``, so names have to be
            # converted to real callables first.
            key = self._get_keyfunc(key)
        return self.__class__(self.pipe(unique_everseen, key=key))

    def groupby(self, *args: Any, **kwds: Any) -> "Self | DataDict":
        """Group by key attribute or function/method.

        Importantly, the key function/values must be sortable.

        Parameters
        ----------
        *args
            First argument is interpreted as key function (or its name).
            The rest is passed as actual ``*args`` to the function.
            No grouping is done if no function/name is passed.
        **kwds
            Passed to the function.

        Returns
        -------
        Self | DataDict
            ``self`` when no key is given, otherwise a :class:`DataDict`
            mapping key values to :class:`DataTuple` groups.
        """
        if not args:
            return self
        key, *args = args
        # Positional arguments must reach both the sort key and the grouping
        # key, as otherwise the two would disagree (or fail outright).
        keyfunc = self._get_keyfunc(key, *args, **kwds)
        groups = DataDict()
        data = DataTuple(self)
        # ``itertools.groupby`` only merges adjacent items, hence the sort.
        for k, g in data.sort(key, *args, **kwds).pipe(groupby, key=keyfunc):
            groups[k] = DataTuple(g)
        return groups

    def zip(self, iterable: Iterable) -> Self:
        """Zip with other iterable."""
        return self.__class__(zip(self, iterable))

    def iter_flat(self) -> Iterable:
        """Iterate over items with one level of nesting flattened.

        Only items which are :class:`DataIterable`, ``tuple`` or ``list``
        instances are expanded; everything else is yielded as is.
        """
        for obj in self:
            if isinstance(obj, DataIterable | tuple | list):
                yield from obj
            else:
                yield obj
class DataIterator(Iterator, DataIterable):
    """Data iterators class.

    Wraps ``iter(data)`` and exposes the :class:`DataIterable` helpers on it.
    Indexing and slicing go through :func:`itertools.islice` applied to the
    iterator itself, so they advance it.
    """
    # pylint: disable=abstract-method
    __slots__ = ("__data__",)

    def __init__(self, data: Iterable, /) -> None:
        self.__data__ = iter(data)

    def __next__(self) -> Any:
        return next(self.__data__)

    def __getitem__(self, idx: int | slice) -> Any | Self:
        """Get an item or a lazily sliced iterator.

        Parameters
        ----------
        idx
            Integer position counted from the current state of the iterator
            or a slice, for which a new instance of the same class is
            returned.

        Notes
        -----
        For integer indices ``StopIteration`` propagates from ``next``
        when the iterator is exhausted before reaching ``idx``.
        """
        if not isinstance(idx, slice):
            window = islice(self, idx, idx+1)
            return next(window)
        sliced = islice(self, idx.start, idx.stop, idx.step)
        return self.__class__(sliced)
@total_ordering
class DataSequence(Sequence, DataIterable):
    """Data sequence class.

    Abstract sequence combining :class:`Sequence` with the
    :class:`DataIterable` helpers and adding :meth:`pairwise`
    and :meth:`sort`.
    """

    @abstractmethod
    def __repr__(self) -> str:
        pass

    @abstractmethod
    def __getitem__(self, idx: int | slice) -> Any | Self:
        pass

    @abstractmethod
    def __len__(self) -> int:
        pass

    @abstractmethod
    def __eq__(self, other: Any) -> bool:
        pass

    @abstractmethod
    def __lt__(self, other: Any) -> bool:
        pass

    def pairwise(self) -> Iterable[tuple[Any, Any]]:
        """Iterate over all pairs of data items.

        This is the full cartesian product of the sequence with itself,
        so both orderings and self-pairs are included.
        """
        return self.__class__(product(self, self))

    def sort(
        self,
        *args: Any,
        reverse: bool = False,
        show_keys: bool = False,
        **kwds: Any
    ) -> Self:
        """Sort elements.

        It is typically best to first flatten the sequence
        in case it contains nested sequences.

        Parameters
        ----------
        *args
            Name of an attribute or a method defined on items.
            Alternatively a callable. Further positional arguments
            are passed to the key function. If no positional arguments
            are used then standard data item sorting is used.
            Alternatively, an iterable of values used for sorting
            can be passed.
        reverse
            Should items be sorted in the descending order.
        show_keys
            Should sorting key values be returned together
            with the objects (so 2-tuples are returned).
        **kwds
            Passed to the sorting callable.

        Raises
        ------
        ValueError
            If the key is neither a callable, a name nor an iterable,
            or if an iterable of sorting values has different length
            than the data.
        """
        if not args:
            # Natural ordering of items; each item is its own sorting key.
            data = sorted(self, reverse=reverse)
            if show_keys:
                data = [(item, item) for item in data]
            return self.__class__(data)

        key, *args = args
        # Strings are iterable too, so names must be recognized first.
        if isinstance(key, str) or callable(key):
            keyfunc = self._get_keyfunc(key, *args, **kwds)
            data = sorted(self, key=keyfunc, reverse=reverse)
            if show_keys:
                # Both sorts are stable, so sorted keys line up with items.
                data = zip(sorted(self.map(keyfunc), reverse=reverse), data)
        elif isinstance(key, Iterable):
            keys = tuple(key)
            if len(self) != len(keys):
                raise ValueError(
                    "sequence used for sorting must be of the same length as data"
                )
            data = sorted(zip(keys, self), key=lambda x: x[0], reverse=reverse)
            if not show_keys:
                data = [x[1] for x in data]
        else:
            raise ValueError(
                "'key' must be a callable or its name or "
                "an iterable of sorting values"
            )
        return self.__class__(data)
[docs] class DataMapping(Mapping, DataABC): """Abstract base class for data mappings.""" # pylint: disable=abstract-method _what_vals = ("items", "keys", "values")
[docs] def keys(self) -> DataSequence: return DataTuple(super().keys())
[docs] def values(self) -> DataSequence: return DataTuple(super().values())
[docs] def items(self) -> DataSequence[tuple[Any, Any]]: return DataTuple(super().items())
[docs] def map(self, _what: Literal[*_what_vals], *args: Any, **kwds: Any) -> Self: """Map over keys, values or items and return a transformed dictionary. Parameters ---------- _what Part of dictionary to process. *args, **kwds Passed to :meth:`DataIteratorABC.map`. """ if _what not in self._what_vals: raise ValueError( f"data dictionary can be mapped only over one of: {self._what_vals}" ) if _what == "items": return self.__class__(self.items().map(*args, **kwds)) keys = self.keys() vals = self.values() if _what == "keys": keys = keys.map(*args, **kwds) else: vals = vals.map(*args, **kwds) return self.__class__(zip(keys, vals))
[docs] def filter(self, _what: Literal[*_what_vals], *args: Any, **kwds: Any) -> Self: """Filter dictonary by keys, values or items. Parameters ---------- _what Part of dictionary to process. *args, **kwds Passed to :meth:`DataIteratorABC.map`. """ if _what not in self._what_vals: raise ValueError( f"data dictionary can be filtered only over one of: {self._what_vals}" ) if _what == "items": return self.__class__(self.items().filter(*args, **kwds)) if args: func, *args = args if _what == "keys": flt = lambda item: func(item[0], *args, **kwds) else: flt = lambda item: func(item[1], *args, **kwds) return self.__class__(self.items().filter(flt))
[docs] def sort(self, _what: Literal[*_what_vals], *args: Any, **kwds: Any) -> Self: """Sort dictionary. Parameters ---------- _what Part of dictionary to process. *args, **kwds Passed to :meth:`DataSequenceABC.sort`, which is called on ``self.items()``. """ return self._apply("sort", _what, *args, **kwds)
# Internals --------------------------------------------------------------- def _apply(self, __method__: str, _what: str, *args: Any, **kwds: Any) -> Self: if _what not in self._what_vals: raise ValueError( f"data dictionary can be filtered only over one of: {self._what_vals}" ) if args: func, *args = args else: func = None if _what == "keys": idx = 0 elif _what == "values": idx = 1 else: idx = None def flt(item, idx, func, *args, **kwds): if idx is not None: item = item[idx] if func is not None: item = func(item, *args, **kwds) return item method = getattr(self.items(), __method__) return self.__class__(method(flt, idx, func, *args, **kwds))
class DataTuple(tuple, DataSequence):
    """Data tuple class.

    Built-in ``tuple`` extended with the :class:`DataSequence` methods.
    """
class DataList(list, DataSequence):
    """Data list class.

    Built-in ``list`` extended with the :class:`DataSequence` methods.
    """
class DataDict(dict, DataMapping):
    """Data dict class.

    Built-in ``dict`` extended with the :class:`DataMapping` methods.
    """
    # ``dict`` precedes ``DataMapping`` in the bases, so the accessors
    # are rebound explicitly to the ``DataMapping`` implementations.
    keys = DataMapping.keys
    values = DataMapping.values
    items = DataMapping.items