Source code for datadings.reader.list

from typing import Union
from typing import Sequence
from typing import Iterable
from typing import Callable
from pathlib import Path

from natsort import natsorted

from .reader import Reader
from ..tools.msgpack import packb


def load_lines(path):
    """
    Read a UTF-8 encoded text file and return its lines as a list.

    Newline and space characters are stripped from both ends of every
    line. Other whitespace, such as tabs, is left untouched.

    Parameters:
        path: Path of the file to read.

    Returns:
        List with one string per line of the file.
    """
    with open(path, encoding='utf-8') as handle:
        raw_lines = handle.readlines()
    return [line.strip('\n ') for line in raw_lines]
# Sentinel used by sorted_labels as the default for samples without a
# "label" key, so a missing label can be told apart from an explicit None.
canary = object()
def sorted_labels(samples):
    """
    Collect the distinct labels found in samples and natsort them.

    Samples without a ``"label"`` key are ignored.

    Parameters:
        samples: Iterable of sample dicts.

    Returns:
        List of the unique labels, ordered by
        :py:func:`natsorted <natsort.natsorted>`.

    Raises:
        ValueError: If any sample has ``None`` as its label.
    """
    # the canary stands in for missing labels, so that an explicit
    # None label can be told apart from no label at all
    found = {sample.get('label', canary) for sample in samples}
    if None in found:
        raise ValueError('Found None as label. '
                         'Labels must be given to reader to use None.')
    # drop the placeholder again before sorting
    found.discard(canary)
    return natsorted(found)
def noop(_):
    """
    Accept a single argument, do nothing with it and return ``None``.

    Serves as the default for the optional ``initfun`` and ``convertfun``
    callables of :py:class:`ListReader`.
    """
    return None
class ListReader(Reader):
    """
    Reader that holds a list of samples.
    Functions can be given to load data on the fly and/or perform
    further conversion steps.

    Two special keys ``"key"`` and ``"label"`` in samples are used:

    - ``"key"`` is a unique identifier for samples.
      Sample index is added to samples if it is missing.
    - ``"label"`` holds an optional label.
      Replaced by a numeric index to the list of labels if
      ``numeric_labels`` is true.
      The original label is retained as ``"_label"``.

    Note:
        If ``labels`` argument is not given, the list of all labels
        will be extracted from all samples.
        The list of all labels is
        :py:func:`natsorted <natsort.natsorted>`
        to determine numerical labels.

    Note:
        ``initfun`` is applied to the given samples during
        initialization and its changes thus remain for the life of
        the reader.
        ``convertfun`` is applied to a shallow copy of the sample
        every time before it is returned.

    Important:
        Since ``None`` is not sortable, the ``labels`` argument
        must be given to use ``None`` as a label.

    Parameters:
        samples: Sequence of samples. Must be indexable, so no
                 generators or one-time iterators.
        labels: Optional. List of labels in desired order,
                or path to file with one label per line.
                Iterables are copied into a list, so one-time
                iterators may be given.
                If ``None`` (or empty), get ``"label"`` keys from
                samples, if any, and sort.
        numeric_labels: If true, convert labels to numeric index to
                        list of all labels.
        initfun: Callable ``initfun(sample: dict)`` to modify
                 samples in-place during initialization.
        convertfun: Callable ``convertfun(sample: dict)`` to modify
                    a shallow copy of samples in-place before they
                    are returned.

    Raises:
        ValueError: If two samples share the same key, or if labels
                    are extracted from samples and ``None`` is found
                    as a label.
    """
    def __init__(
            self,
            samples: Sequence[dict],
            labels: Union[Iterable, Path, None] = None,
            numeric_labels=True,
            initfun: Callable = noop,
            convertfun: Callable = noop,
    ):
        super().__init__()
        self._convertfun = convertfun
        self._samples = samples
        # maps sample key to position in self._samples
        self._index = {}
        for i, sample in enumerate(self._samples):
            initfun(sample)
            # the position doubles as key for samples that have none
            key = sample.get('key', i)
            if key in self._index:
                raise ValueError(f'duplicate key {key!r}')
            sample['key'] = key
            self._index[key] = i
        if not labels:
            # nothing given: collect the labels from the samples
            self.labels = sorted_labels(self._samples)
        else:
            try:
                # open() raises TypeError for anything that is not
                # path-like, which is how label collections are told
                # apart from label files
                self.labels = load_lines(labels)
            except TypeError:
                # materialize, since the labels are enumerated more
                # than once below and a one-time iterator would be
                # exhausted after the first pass
                self.labels = list(labels)
        self._numeric_labels = numeric_labels
        if numeric_labels:
            # labels can be looked up by their string form ...
            self._label_index = {
                str(label): i for i, label in enumerate(self.labels)
            }
            # ... and by the label itself, which takes precedence
            self._label_index.update(
                {label: i for i, label in enumerate(self.labels)}
            )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # nothing to release, samples are held in memory
        pass

    def __len__(self):
        return len(self._samples)

    def __contains__(self, key):
        return key in self._index

    def find_key(self, index):
        """Return the key of the sample at the given index."""
        return self._samples[index]['key']

    def find_index(self, key):
        """
        Return the index of the sample with the given key.

        Raises:
            KeyError: If no sample has this key.
        """
        return self._index[key]

    def get(self, index, yield_key=False, raw=False, copy=True):
        """
        Return the sample at the given index.

        A shallow copy of the stored sample is made, the label is
        replaced by its numeric index (if ``numeric_labels`` was set
        and the sample has a label) and ``convertfun`` is applied.

        Parameters:
            index: Position of the sample.
            yield_key: If true, return ``(key, sample)`` instead of
                       just the sample.
            raw: If true, return the sample packed with ``packb``
                 instead of a dict.
            copy: Accepted, but not used by this reader.
                  A shallow copy is always made.

        Returns:
            The sample, or a tuple of key and sample.

        Raises:
            KeyError: If numeric labels are used and the label of
                      the sample is not in the list of labels.
        """
        # shallow copy, so the changes below never reach the stored sample
        sample = dict(self._samples[index])
        if self._numeric_labels and 'label' in sample:
            sample['_label'] = sample['label']
            sample['label'] = self._label_index[sample['label']]
        self._convertfun(sample)
        # remember the key now, it is not accessible once packed
        key = sample['key']
        if raw:
            sample = packb(sample)
        if yield_key:
            return key, sample
        return sample

    def slice(self, start, stop=None, yield_key=False, raw=False, copy=True):
        """
        Yield the samples with indexes from ``start`` up to, but not
        including, ``stop``.

        ``start`` and ``stop`` follow the rules of Python slices:
        negative values count from the end, out-of-range values are
        clipped, and a ``stop`` of ``None`` means the end of the
        samples.

        Parameters:
            start: Index of the first sample.
            stop: Index one past the last sample.
            yield_key: If true, yield ``(key, sample)`` tuples.
            raw: If true, yield samples packed with ``packb``.
            copy: Passed on to :py:meth:`get`, which does not use it.

        Returns:
            Generator of samples, as returned by :py:meth:`get`.
        """
        # inside a method body the name slice is still the builtin
        start, stop, _ = slice(start, stop).indices(len(self._samples))
        for index in range(start, stop):
            yield self.get(index, yield_key=yield_key, raw=raw, copy=copy)