"""
An Augment wraps a
:py:class:`Reader <datadings.reader.reader.Reader`
and changes how samples are iterated over.
How readers are used is largely unaffected.
"""
from abc import ABCMeta, abstractmethod
from math import ceil
from random import Random
__all__ = ('Range', 'Repeater', 'Cycler', 'Shuffler', 'QuasiShuffler')
class Augment(object):
"""
Base class for Augments.
Warning:
Augments are not thread safe!
Parameters:
reader: the reader to augment
"""
__metaclass__ = ABCMeta
def __init__(self, reader):
self._reader = reader
def __enter__(self):
self._reader.__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._reader.__exit__(exc_type, exc_val, exc_tb)
def __len__(self):
return len(self._reader)
def __iter__(self):
return self.iter()
@abstractmethod
def iter(
self,
yield_key=False,
raw=False,
copy=True,
chunk_size=16,
):
"""
Create an iterator.
Parameters:
yield_key: if True, yields (key, sample) pairs.
raw: if True, yields samples as msgpacked messages.
copy: if False, allow the reader to return data as
``memoryview`` objects instead of ``bytes``
chunk_size: number of samples read at once;
bigger values can increase throughput,
but also memory
Returns:
Iterator
"""
pass
def rawiter(self, yield_key=False):
"""
Create an iterator that yields samples as msgpacked messages.
Order and number of samples is determined by the Augment.
Included for backwards compatibility and may be deprecated and
subsequently removed in the future.
Parameters:
yield_key: If True, yields (key, sample) pairs.
Returns:
Iterator
"""
return self.iter(yield_key=yield_key, raw=True)
@abstractmethod
def seek(self, index):
pass
[docs]class Range(Augment):
"""
Extract a range of samples from a given reader.
``start`` and ``stop`` behave like the parameters of the
``range`` function.
Parameters:
reader: reader to sample from
start: start of range
stop: stop of range
"""
def __init__(self, reader, start=0, stop=None):
super().__init__(reader)
n = len(reader)
if start < 0:
start += n
if start < 0 or start >= n:
raise IndexError(f'index {start} out of range for length {n} reader')
self.start, self.stop, _ = slice(start, stop).indices(n)
[docs] def iter(
self,
yield_key=False,
raw=False,
copy=True,
chunk_size=16,
):
return self._reader.iter(
start=self.start,
stop=self.stop,
yield_key=yield_key,
raw=raw,
copy=copy,
chunk_size=chunk_size,
)
[docs] def seek(self, index):
self._reader.seek(self.start + index)
[docs]class Shuffler(Augment):
"""
Iterate over a
:py:class:`Reader <datadings.reader.reader.Reader` in random order.
Parameters:
reader: The reader to augment.
seed: optional random seed; defaults to len(reader)
Warning:
Augments are not thread safe!
"""
def __init__(self, reader, seed=None):
super().__init__(reader)
self._n = len(reader)
self._seed = self._n if seed is None else seed
self._offset = 0
self._i = 0
[docs] def seek(self, index):
self._i = index
self._offset = index // self._n * self._n
[docs] def iter(
self,
yield_key=False,
raw=False,
copy=True,
chunk_size=16,
):
n = self._n
rand = Random()
rand.seed(self._seed + self._offset, version=2)
order = list(range(n))
rand.shuffle(order)
for i in order[self._i:]:
yield self._reader.get(i, yield_key=yield_key, raw=raw, copy=copy)
self._i += 1
self._i = 0
self._offset += self._n
class _Placeholder(int):
def __eq__(self, other):
return int.__eq__(self, other)
[docs]class QuasiShuffler(Augment):
"""
A slightly less random than a true
:py:class:`Reader <datadings.reader.augment.Shuffler` but much faster.
The dataset is divided into equal-size chunks that are read in random
order.
Shuffling follows these steps:
1. Fill the buffer with chunks.
2. Read the next chunk.
3. Select a random sample from the buffer and yield it.
4. Replace the sample with the next sample from the current chunk.
5. If there are chunks left, goto 2.
This means there are typically more samples from the current chunk
in the buffer than there would be if a true shuffle was used.
This effect is more pronounced for smaller fractions :math:`\\frac{B}{C}`
where :math:`C` is the chunk size and :math:`B` the buffer size.
As a rule of thumb it is sufficient to keep :math:`\\frac{B}{C}` roughly
equal to the number of classes in the dataset.
Note:
Seeking and resuming iteration with a new iterator are relatively
costly operations. If possible create one iterator and use it
repeatedly.
Parameters:
reader: the reader to wrap
buf_size: size of the buffer; values less than 1 are interpreted
as fractions of the dataset length; bigger values improve
randomness, but use more memory
chunk_size: size of each chunk; bigger values improve performance,
but reduce randomness
seed: random seed to use;
defaults to ``len(reader) * self.buf_size * chunk_size``
"""
def __init__(self, reader, buf_size=0.01, chunk_size=16, seed=None):
super().__init__(reader)
self._i = 0
self._n = len(reader)
if buf_size < 1:
buf_size = ceil(self._n * 0.01)
buf_size = int(buf_size)
self.reader = reader
# buf size is a multiple of chunk_size
self.buf_size = int(ceil(buf_size / chunk_size)) * chunk_size
self.chunk_size = chunk_size
self.num_chunks = ceil(self._n / chunk_size)
self._seed = len(reader) * self.buf_size * chunk_size if seed is None else seed
self._offset = 0
[docs] def seek(self, index):
if index < 0:
raise IndexError('index must be > 0')
self._i = index % self._n
self._offset = index // self._n * self._n
# noinspection PyMethodOverriding
[docs] def iter(
self,
yield_key=False,
raw=False,
copy=True,
chunk_size=None,
):
chunk_size = chunk_size or self.chunk_size
rand = Random()
rand.seed(self._seed + self._offset, version=2)
chunk_order = list(range(self.num_chunks))
rand.shuffle(chunk_order)
chunks = ((
c * chunk_size,
min(self._n, (c + 1) * chunk_size)
) for c in chunk_order)
reader = self.reader
# create buffer
buffer = []
# for index < buffer size, fill buffer with actual data
if self._i < self.buf_size:
for _, (a, b) in zip(range(self.buf_size // chunk_size), chunks):
buffer.extend(reader.slice(a, b, yield_key=yield_key, raw=raw))
# for larger index, fill with placeholders
else:
for _, (a, b) in zip(range(self.buf_size // chunk_size), chunks):
buffer.extend(map(_Placeholder, range(a, b)))
# buffer may be smaller than requested if last chunk is used and
# dataset does not cleanly divide into chunks
buf_size = len(buffer)
i = 0
# yield from remaining chunks
for a, b in chunks:
index = a
# store placeholders until current index is reached
if i < self._i:
for index in range(a, b):
if i >= self._i:
break
buffer_pos = rand.randrange(buf_size)
buffer[buffer_pos] = _Placeholder(index)
i += 1
# once index is reached, read samples from reader
if i >= self._i:
for sample in reader.slice(index, b, yield_key=yield_key, raw=raw):
buffer_pos = rand.randrange(buf_size)
buffer_value = buffer[buffer_pos]
if type(buffer_value) is _Placeholder:
buffer_value = reader.get(buffer_value, yield_key=yield_key, raw=raw)
yield buffer_value
buffer[buffer_pos] = sample
self._i += 1
i += 1
# yield rest of buffer
buffer_start = max(0, buf_size - self._n + self._i)
for buffer_value in buffer[buffer_start:]:
if type(buffer_value) is _Placeholder:
buffer_value = reader.get(buffer_value, yield_key=yield_key, raw=raw)
yield buffer_value
self._i += 1
self._i = 0
self._offset += self._n
[docs]class Repeater(Augment):
"""
Repeat a :py:class:`Reader <datadings.reader.reader.Reader`
a fixed number of times.
Warning:
Augments are not thread safe!
"""
def __init__(self, reader, times):
super().__init__(reader)
self.times = times
[docs] def iter(
self,
yield_key=False,
raw=False,
copy=True,
chunk_size=16,
):
for _ in range(self.times):
yield from self._reader.iter(
yield_key=yield_key,
raw=raw,
copy=copy,
chunk_size=chunk_size,
)
[docs] def seek(self, index):
self._reader.seek(index % len(self._reader))
[docs]class Cycler(Augment):
"""
Infinitely cycle a :py:class:`Reader <datadings.reader.reader.Reader`.
Warning:
Augments are not thread safe!
"""
[docs] def iter(
self,
yield_key=False,
raw=False,
copy=True,
chunk_size=16,
):
while 1:
yield from self._reader.iter(
yield_key=yield_key,
raw=raw,
copy=copy,
chunk_size=chunk_size,
)
[docs] def seek(self, index):
self._reader.seek(index % len(self._reader))