"""
An Augment wraps a
:py:class:`Reader <datadings.reader.reader.Reader>`
and changes how samples are iterated over.
How readers are used is largely unaffected.
"""
from abc import ABCMeta
import itertools as it
from math import ceil
from random import Random
from .reader import Reader
__all__ = ('Range', 'Repeater', 'Cycler', 'Shuffler', 'QuasiShuffler')
class Augment(Reader, metaclass=ABCMeta):
    """
    Base class for Augments.

    An Augment wraps a reader and delegates length, containment, and the
    random access methods ``find_key``, ``find_index``, ``get``, and
    ``slice`` to it.  Subclasses override these to change how samples
    are accessed.

    Warning:
        Augments are not thread safe!

    Parameters:
        reader: the reader to augment
    """
    # note: the Python 2 ``__metaclass__`` attribute was removed; the
    # ``metaclass=ABCMeta`` keyword above is the Python 3 mechanism

    def __init__(self, reader):
        super().__init__()
        # cache the length so subclasses need not touch the wrapped
        # reader for every __len__ call
        self._len = len(reader)
        self._reader = reader

    def __enter__(self):
        self._reader.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # propagate the wrapped reader's return value so exception
        # suppression (a truthy return) is honored by the with statement
        return self._reader.__exit__(exc_type, exc_val, exc_tb)

    def __len__(self):
        return self._len

    def __contains__(self, key):
        return key in self._reader

    def find_key(self, index):
        return self._reader.find_key(index)

    def find_index(self, key):
        return self._reader.find_index(key)

    def get(self, index, yield_key=False, raw=False, copy=True):
        return self._reader.get(index=index, yield_key=yield_key, raw=raw, copy=copy)

    def slice(self, start, stop=None, yield_key=False, raw=False, copy=True):
        return self._reader.slice(start=start, stop=stop, yield_key=yield_key, raw=raw, copy=copy)
class Range(Augment):
    """
    Extract a range of samples from a given reader.

    ``start`` and ``stop`` behave like the parameters of the
    :python:`range` function.

    Parameters:
        reader: reader to sample from
        start: start of range
        stop: stop of range
    """
    def __init__(self, reader, start=0, stop=None):
        super().__init__(reader)
        if start < 0:
            start += self._len
        if start < 0 or start >= self._len:
            raise IndexError(f'index {start} out of range for length {self._len} reader')
        # normalize start/stop exactly like the builtin slice would
        self.start, self.stop, _ = slice(start, stop).indices(self._len)
        # number of samples in this range
        self.n = self.stop - self.start

    def __len__(self):
        return self.stop - self.start

    def __contains__(self, key):
        try:
            self.find_index(key)
            return True
        except KeyError:
            return False

    def find_key(self, index):
        if index >= self.n:
            raise IndexError("index out of range")
        return self._reader.find_key(index + self.start)

    def find_index(self, key):
        i = self._reader.find_index(key)
        if i < self.start or i >= self.stop:
            raise KeyError(key)
        return i - self.start

    def get(self, index, yield_key=False, raw=False, copy=True):
        # index == self.n is already out of range, hence >= rather than >
        if index >= self.n or index < -self.n:
            raise IndexError("index out of range")
        if index < 0:
            # negative index counts back from the end of the range;
            # adding the (negative) index to stop maps -1 to stop-1
            index = self.stop + index
        else:
            index += self.start
        return self._reader.get(index=index, yield_key=yield_key, raw=raw, copy=copy)

    def slice(self, start, stop=None, yield_key=False, raw=False, copy=True):
        # shift the requested window into the wrapped reader's index space
        start += self.start
        if stop is None:
            stop = self.n
        stop += self.start
        return self._reader.slice(start=start, stop=stop, yield_key=yield_key, raw=raw, copy=copy)
class Shuffler(Augment):
    """
    Iterate over a
    :py:class:`Reader <datadings.reader.reader.Reader>` in random order.

    If no seed is given the length of the reader is used for reproducibility.
    Creating an iterator increments the seed by 1.
    Use :py:meth:`Shuffler.seed` to set the desired seed instead.

    Warning:
        Shuffler only implements iteration.
        Random access methods ``find_index``, ``find_key``, ``get``,
        and ``slice`` raise ``NotImplementedError``.

    Parameters:
        reader: The reader to augment.
        seed: optional random seed; defaults to len(reader)

    Warning:
        Augments are not thread safe!
    """
    def __init__(self, reader, seed=None):
        super().__init__(reader)
        # fall back to the dataset length as a reproducible default seed
        self._seed = seed if seed is not None else self._len

    def seed(self, seed):
        # seed used by the next iterator created from this Shuffler
        self._seed = seed

    def _iter_impl(
            self,
            start,
            stop,
            yield_key=False,
            raw=False,
            copy=True,
            chunk_size=16,
    ):
        rng = Random()
        rng.seed(self._seed, version=2)
        # bump the seed so every new iterator produces a different order
        self._seed += 1
        permutation = list(range(self._len))
        rng.shuffle(permutation)
        fetch = self._reader.get
        for position in permutation[start:stop]:
            yield fetch(position, yield_key=yield_key, raw=raw, copy=copy)

    def find_key(self, index):
        raise NotImplementedError("Shuffler does not implement random access")

    def find_index(self, key):
        raise NotImplementedError("Shuffler does not implement random access")

    def get(self, index, yield_key=False, raw=False, copy=True):
        raise NotImplementedError("Shuffler does not implement random access")

    def slice(self, start, stop=None, yield_key=False, raw=False, copy=True):
        raise NotImplementedError("Shuffler does not implement random access")
class _Placeholder(int):
def __eq__(self, other):
return int.__eq__(self, other)
class QuasiShuffler(Augment):
    """
    A slightly less random than a true
    :py:class:`Shuffler <datadings.reader.augment.Shuffler>` but much faster.

    The dataset is divided into equal-size chunks that are read in random
    order.
    Shuffling follows these steps:

    1. Fill the buffer with random chunks.
    2. Read the next random chunk.
    3. Select a random sample from the buffer and yield it.
    4. Replace the sample with the next sample from the current chunk.
    5. If there are chunks left, goto 2.
    6. Shuffle the buffer and yield its contents.

    This means there are typically more samples from the current chunk
    in the buffer than there would be if a true shuffle was used.
    This effect is more pronounced for smaller fractions :math:`\\frac{B}{C}`
    where :math:`C` is the chunk size and :math:`B` the buffer size.
    As a rule of thumb it is sufficient to keep :math:`\\frac{B}{C}` roughly
    equal to the number of classes in the dataset.

    Note:
        Creating a new iterator, especially from a specific start position,
        is a costly operation. If possible create one iterator and use it
        until it is exhausted.

    Parameters:
        reader: the reader to wrap
        buf_size: size of the buffer; values less than 1 are interpreted
            as fractions of the dataset length; bigger values improve
            randomness, but use more memory
        seed: random seed to use;
            defaults to ``len(reader) * buf_size * chunk_size``
    """
    def __init__(self, reader, buf_size=0.01, seed=None):
        super().__init__(reader)
        if buf_size < 1:
            # fractional values are relative to the dataset length
            buf_size = ceil(self._len * buf_size)
        self.buf_size = int(buf_size)
        self.reader = reader
        self._seed = seed
        # incremented per iterator so each one shuffles differently
        # while remaining reproducible for a fixed seed
        self._offset = 0

    def seed(self, seed):
        """Set the seed and reset the per-iterator offset."""
        self._seed = seed
        self._offset = 0

    def _iter_impl(
            self,
            start,
            stop,
            yield_key=False,
            raw=False,
            copy=True,
            chunk_size=16,
    ):
        self._offset += 1
        n = stop - start
        # early stop if nothing to do
        if n <= 0:
            return
        # buf size is a multiple of chunk_size
        buf_size = int(ceil(self.buf_size / chunk_size)) * chunk_size
        num_chunks = ceil(self._len / chunk_size)
        rand = Random()
        seed = self._seed
        if seed is None:
            seed = len(self.reader) * self.buf_size * chunk_size
        # -1 because we incremented offset earlier
        rand.seed(seed + self._offset - 1, version=2)
        chunk_order = list(range(num_chunks))
        rand.shuffle(chunk_order)
        # lazy (start, stop) bounds of each chunk in shuffled order;
        # the last chunk may be short, hence the min()
        chunks = ((
            c * chunk_size,
            min(self._len, (c + 1) * chunk_size)
        ) for c in chunk_order)
        reader = self.reader
        # create buffer
        buffer = []
        # for index < buffer size, fill buffer with actual data;
        # copy is forwarded so callers get the semantics they requested
        if start < buf_size:
            for a, b in it.islice(chunks, buf_size // chunk_size):
                buffer.extend(reader.slice(a, b, yield_key=yield_key, raw=raw, copy=copy))
        # for larger index, fill with placeholders that are only
        # resolved to real samples when they are actually yielded
        else:
            for a, b in it.islice(chunks, buf_size // chunk_size):
                buffer.extend(map(_Placeholder, range(a, b)))
        # buffer may be smaller than requested if last chunk is used and
        # dataset does not cleanly divide into chunks
        buf_size = len(buffer)
        i = 0
        n = stop - start
        # yield from remaining chunks
        for a, b in chunks:
            index = a
            # store placeholders until current index is reached
            if i < start:
                for index in range(a, b):
                    if i >= start:
                        break
                    buffer_pos = rand.randrange(buf_size)
                    buffer[buffer_pos] = _Placeholder(index)
                    i += 1
            # once index is reached, read samples from reader
            if i >= start:
                for sample in reader.slice(index, b, yield_key=yield_key, raw=raw, copy=copy):
                    if n <= 0:
                        break
                    buffer_pos = rand.randrange(buf_size)
                    buffer_value = buffer[buffer_pos]
                    # resolve placeholders lazily on first access
                    if type(buffer_value) is _Placeholder:
                        buffer_value = reader.get(buffer_value, yield_key=yield_key, raw=raw, copy=copy)
                    yield buffer_value
                    buffer[buffer_pos] = sample
                    i += 1
                    n -= 1
            if n <= 0:
                break
        # yield the buffer
        rand.shuffle(buffer)
        for buffer_value in buffer:
            if n <= 0:
                break
            if type(buffer_value) is _Placeholder:
                buffer_value = reader.get(buffer_value, yield_key=yield_key, raw=raw, copy=copy)
            yield buffer_value
            n -= 1

    def find_key(self, index):
        raise NotImplementedError("QuasiShuffler does not implement random access")

    def find_index(self, key):
        raise NotImplementedError("QuasiShuffler does not implement random access")

    def get(self, index, yield_key=False, raw=False, copy=True):
        raise NotImplementedError("QuasiShuffler does not implement random access")

    def slice(self, start, stop=None, yield_key=False, raw=False, copy=True):
        raise NotImplementedError("QuasiShuffler does not implement random access")
def _repeater_index(i, length, total, is_stop=False):
if i < -total or i >= total + is_stop:
raise IndexError(f'index {i} out of range for length {total} reader')
if i < 0:
i += total
return i % length, i // length
class Repeater(Augment):
    """
    Repeat a :py:class:`Reader <datadings.reader.reader.Reader>`
    a fixed number of times.

    Note:
        ``find_index`` returns the first occurrence.

    Parameters:
        reader: the reader to repeat
        times: number of repetitions
    """
    def __init__(self, reader, times):
        super().__init__(reader)
        self.times = times

    def __len__(self):
        # total length is the wrapped length times the repeat count
        return self._len * self.times

    def __contains__(self, key):
        return key in self._reader

    def find_key(self, index):
        # map indexes of later repetitions back onto the wrapped reader,
        # consistent with get(); also validates the index range
        index, _ = _repeater_index(index, self._len, self._len * self.times)
        return self._reader.find_key(index)

    def find_index(self, key):
        # first occurrence only (see class Note)
        return self._reader.find_index(key)

    def get(self, index, yield_key=False, raw=False, copy=True):
        total = self._len * self.times
        index, _ = _repeater_index(index, self._len, total)
        return self._reader.get(index=index, yield_key=yield_key, raw=raw, copy=copy)

    def slice(self, start, stop=None, yield_key=False, raw=False, copy=True):
        yield from self.iter(start, stop=stop, yield_key=yield_key, raw=raw, copy=copy)

    def _iter_impl(
            self,
            start,
            stop,
            yield_key=False,
            raw=False,
            copy=True,
            chunk_size=16,
    ):
        total = self._len * self.times
        # actual start/stop index in the wrapped reader
        # and which number repetition start/stop fall into
        start, start_rep = _repeater_index(start, self._len, total)
        stop, stop_rep = _repeater_index(stop, self._len, total, is_stop=True)
        # simplest case: start and stop fall into the same repetition;
        # this is a generator function, so a bare ``return iterator``
        # would silently discard the iterator -- it must be yielded from
        if start_rep == stop_rep:
            yield from self._reader.iter(
                start,
                stop=stop,
                yield_key=yield_key,
                raw=raw,
                copy=copy,
                chunk_size=chunk_size,
            )
            return
        # otherwise first yield from the start index to the reader's end
        yield from self._reader.iter(
            start,
            stop=None,
            yield_key=yield_key,
            raw=raw,
            copy=copy,
            chunk_size=chunk_size,
        )
        # the number of full iterations between start and stop
        for _ in range(stop_rep - start_rep - 1):
            yield from self._reader.iter(
                0,
                stop=self._len,
                yield_key=yield_key,
                raw=raw,
                copy=copy,
                chunk_size=chunk_size,
            )
        # and the remaining items until the stop index
        yield from self._reader.iter(
            0,
            stop=stop,
            yield_key=yield_key,
            raw=raw,
            copy=copy,
            chunk_size=chunk_size,
        )
def _cycler_index(i, length):
if i < 0:
raise IndexError(f'Cycler does not support negative index {i}')
return i % length, i // length
class Cycler(Augment):
    """
    Infinitely cycle a :py:class:`Reader <datadings.reader.reader.Reader>`.

    Iterators can be requested with any start/stop index.
    Large indexes simply wrap around.
    """
    def iter(
            self,
            start=None,
            stop=None,
            yield_key=False,
            raw=False,
            copy=True,
            chunk_size=16,
    ):
        if start is None:
            # start_rep must be defined here as well -- the bounded
            # branch below reads it (previously a NameError when only
            # stop was given)
            start, start_rep = 0, 0
        else:
            start, start_rep = _cycler_index(start, self._len)
        if stop is None:
            # loop forever from the start index
            yield from self._reader.iter(
                start,
                stop=self._len,
                yield_key=yield_key,
                raw=raw,
                copy=copy,
                chunk_size=chunk_size,
            )
            while True:
                yield from self._reader.iter(
                    0,
                    stop=self._len,
                    yield_key=yield_key,
                    raw=raw,
                    copy=copy,
                    chunk_size=chunk_size,
                )
        else:
            stop, stop_rep = _cycler_index(stop, self._len)
            # simplest case: start and stop in the same repetition;
            # this is a generator function, so a bare ``return iterator``
            # would silently discard the iterator -- it must be yielded
            if start_rep == stop_rep:
                yield from self._reader.iter(
                    start,
                    stop=stop,
                    yield_key=yield_key,
                    raw=raw,
                    copy=copy,
                    chunk_size=chunk_size,
                )
                return
            # otherwise first yield from the start index to the end
            yield from self._reader.iter(
                start,
                stop=None,
                yield_key=yield_key,
                raw=raw,
                copy=copy,
                chunk_size=chunk_size,
            )
            # the number of full iterations between start and stop
            for _ in range(stop_rep - start_rep - 1):
                yield from self._reader.iter(
                    0,
                    stop=self._len,
                    yield_key=yield_key,
                    raw=raw,
                    copy=copy,
                    chunk_size=chunk_size,
                )
            # and the remaining items until the stop index
            yield from self._reader.iter(
                0,
                stop=stop,
                yield_key=yield_key,
                raw=raw,
                copy=copy,
                chunk_size=chunk_size,
            )
class Slicer:
    """
    A wrapper for :python:`itertools.islice`.

    ``iter(slicer)`` is equivalent to ``itertools.islice(iterable, length)``.

    Warning:
        ``len(slicer) == length``,
        even though there is no way to actually guarantee this.
    """
    def __init__(self, iterable, length):
        # the iterable is not consumed here; it is sliced lazily
        # every time __iter__ is called
        self.iterable, self.length = iterable, length

    def __len__(self):
        # reported length is taken on faith (see class Warning)
        return self.length

    def __iter__(self):
        return it.islice(self.iterable, self.length)