Source code for datadings.torch

from io import BytesIO
from typing import Iterable
from typing import Callable
import inspect
import warnings

from PIL import Image
from simplejpeg import decode_jpeg
from torch.utils.data import Dataset as _Dataset
from torch.utils.data import IterableDataset as _IterableDataset
from torch.utils.data import get_worker_info
from torch.distributed import is_initialized
from torch.distributed import get_rank
from torch.distributed import get_world_size

from .reader.reader import Reader


def _noop(sample):
    return sample


def _getargs(f):
    # check that f is callable before inspecting its signature,
    # so a descriptive error is raised for non-callables
    if not isinstance(f, Callable):
        raise ValueError(f'transform must be callable, not {f!r}')
    spec = inspect.getfullargspec(f)
    args = spec.args + spec.kwonlyargs
    # remove first arg for callables that are not functions
    # and warn if name of the arg is not self
    if not inspect.isfunction(f):
        if args:
            if args[0] != 'self':
                warnings.warn(
                    f"expected first argument of {f!r} to be 'self', "
                    f"but found {args[0]!r}, arguments given to this "
                    "transform may be incorrect"
                )
            args = args[1:]
    # if there are args, return all but the first arg
    # warn if there are also varargs, which will be ignored
    if args:
        if spec.varargs:
            warnings.warn(f'varargs are ignored for {f!r}')
        return args[1:]
    # if there are no args check if there are varargs instead
    # this is the case for non-functional torchvision transforms
    # warn that this transform can only accept values and no parameters
    else:
        if spec.varargs:
            warnings.warn(f"{f!r} only accepts varargs so "
                          "it will only receive sample values")
        else:
            raise ValueError('transforms must accept at least one argument '
                             f'but {f!r} accepts none')
        return []
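

# Illustrative sketch (not part of datadings): shows which parameter names
# _getargs extracts for a plain function versus a callable class instance.
# The names ``resize``, ``scale``, ``Rotate``, and ``angle`` are made up.
def _getargs_example():
    def resize(value, scale):
        return value

    class Rotate:
        def __call__(self, value, angle):
            return value

    # the first (value) argument and 'self' are stripped,
    # only the remaining parameter names are returned
    assert _getargs(resize) == ['scale']
    assert _getargs(Rotate()) == ['angle']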


class Compose:
    """
    Compose a sequence of transform functions.
    Functions must accept the intended value from samples
    as their first argument.
    They may have an arbitrary number of positional and keyword arguments.

    Example usage with :py:class:`.Dataset`::

        import random
        from datadings.torch import Compose
        from datadings.torch import Dataset
        from datadings.reader import ListReader

        def add(v, number):
            return v + number

        def sub(x, value):
            return x - value

        def rng(_):
            return {
                'number': random.randrange(1, 10),
                'value': random.randrange(1, 10),
            }

        samples = [{'a': 0, 'b': 0, 'c': 0} for _ in range(10)]
        reader = ListReader(samples)
        transforms = {
            'a': Compose(add),
            'b': Compose(sub),
            'c': Compose((add, sub)),
        }
        dataset = Dataset(reader, transforms=transforms, rng=rng)
        for i in range(len(dataset)):
            print(dataset[i])

    Parameters:
        transforms: sequence of transform functions,
                    either one iterable or varargs
        prefix: string prefix for parameter names, i.e.,
                if a function normally requires parameter ``size``,
                then given ``prefix='mask_'`` the parameter
                ``'mask_size'`` is used instead
    """
    def __init__(self, *transforms, prefix=''):
        if len(transforms) == 1 and isinstance(transforms[0], Iterable):
            transforms = transforms[0]
        self.transforms = tuple(transforms)
        # iterate the stored tuple so one-shot iterables
        # are not consumed twice
        self.param_names = [
            tuple(prefix + arg for arg in _getargs(f))
            for f in self.transforms
        ]
        self.prefix = prefix

    def __call__(self, value, params):
        try:
            for f, ps in zip(self.transforms, self.param_names):
                value = f(value, **{p: params[p] for p in ps if p in params})
            return value
        except TypeError as e:
            raise KeyError(
                'Required parameters for transform function are missing. '
                'Incomplete rng function? Missing constants? ' + str(e)
            )
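

# Illustrative sketch (not part of datadings): Compose inspects each function's
# parameter names and only passes values that are present in ``params``, so
# missing parameters fall back to the function's defaults.
# The function and parameter names here are made up for this example.
def _compose_defaults_example():
    def brighten(v, amount=1):
        return v + amount

    c = Compose(brighten)
    assert c(0, {'amount': 5}) == 5
    # 'amount' is absent from params, so the default of 1 applies
    assert c(0, {}) == 1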


def no_rng(_):
    return {}


class DatasetBase:
    def __init__(
            self,
            reader: Reader,
            transforms=None,
            rng=None,
    ):
        self.reader = reader
        self._rng = rng or no_rng
        self._transforms = transforms
        if transforms is not None:
            if isinstance(transforms, dict):
                self.transform = self._transform
            else:
                self.transform = transforms
        else:
            self.transform = _noop

    def _transform(self, sample):
        params = self._rng(sample)
        sample['__params__'] = params
        for k, f in self._transforms.items():
            sample[k] = f(sample[k], params)
        return sample

    def __len__(self):
        return len(self.reader)


class Dataset(DatasetBase, _Dataset):
    """
    Implementation of ``torch.utils.data.Dataset``.

    .. warning::
        :py:class:`~datadings.torch.Dataset` can be significantly slower
        than :py:class:`~datadings.torch.IterableDataset`.
        If shuffling is necessary consider using
        :py:class:`~datadings.reader.augment.QuasiShuffler` instead.

    Example usage with the PyTorch ``DataLoader``::

        path = '.../train.msgpack'
        batch_size = 256
        reader = MsgpackReader(path)
        transforms = {'image': Compose(
            CompressedToPIL(),
            ...,
            ToTensor(),
        )}
        ds = Dataset(reader, transforms=transforms)
        train = DataLoader(dataset=ds, batch_size=batch_size)
        for epoch in range(3):
            for x, y in dict2tuple(tqdm(train)):
                pass

    Parameters:
        reader: the datadings reader instance
        transforms: Transforms applied to samples before they are returned.
            Either a dict of transform functions or a callable with
            signature ``f(sample: dict) -> dict`` that is applied directly
            to samples.
            In the dict form keys correspond to keys in the sample and
            values are callables with signature
            ``t(value: any, params: dict) -> any``
            (e.g., an instance of :py:class:`.Compose`)
            with ``params`` the value returned by the ``rng`` callable.
        rng: callable with signature ``rng(sample: dict) -> dict`` that
            returns a dict of parameters applied to transforms
    """
    def __getitem__(self, index):
        return self.transform(self.reader.get(index))
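

# Illustrative sketch (not part of datadings): the non-dict transform form,
# where a single callable receives and returns the whole sample dict.
# Uses ListReader as in the Compose example above; sample keys are made up.
def _dataset_callable_transform_example():
    from datadings.reader import ListReader

    def swap_label(sample):
        sample['label'] = 1 - sample['label']
        return sample

    reader = ListReader([{'key': str(i), 'label': i % 2} for i in range(4)])
    dataset = Dataset(reader, transforms=swap_label)
    # the callable is applied to every sample on access
    return [dataset[i]['label'] for i in range(len(dataset))]  # [1, 0, 1, 0]

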
# noinspection PyAbstractClass
class IterableDataset(DatasetBase, _IterableDataset):
    """
    Implementation of ``torch.utils.data.IterableDataset``
    to use with datadings readers.

    With distributed training the reader is divided into
    ``world_size * num_workers`` shards.
    Each dataloader worker of each rank iterates over a different shard.
    The final batch delivered by a worker may be smaller than the batch
    size if the length of the reader is not divisible by
    ``batch_size * num_shards``.

    .. note::
        Set ``persistent_workers=True`` for the ``DataLoader`` to let the
        dataset object track the current epoch.
        Ranks then cycle through the shards of the dataset as epochs
        advance.
        Without this option torch may create new worker processes at any
        time, which resets the dataset to its initial state.

    .. warning::
        Raises ``RuntimeError`` if ``0 < len(shard) % batch_size < 1``,
        since this may lead to an uneven number of batches generated by
        each worker.
        This can lead to crashes if it happens between the workers of a
        rank, or to deadlocks if ranks receive a different number of
        batches.
        Change ``num_workers``, ``batch_size``, or ``world_size`` to
        avoid this.

    Example usage with the PyTorch ``DataLoader``::

        path = '.../train.msgpack'
        batch_size = 256
        reader = MsgpackReader(path)
        transforms = {'image': Compose(
            CompressedToPIL(),
            ...,
            ToTensor(),
        )}
        ds = IterableDataset(
            reader,
            transforms=transforms,
            batch_size=batch_size,
        )
        train = DataLoader(
            dataset=ds,
            batch_size=batch_size,
            num_workers=4,
            persistent_workers=True,
        )
        for epoch in range(3):
            print('Epoch', epoch)
            for x, y in dict2tuple(tqdm(train)):
                pass

    Parameters:
        reader: the datadings reader instance
        transforms: Transforms applied to samples before they are returned.
            Either a dict of transform functions or a callable with
            signature ``f(sample: dict) -> dict`` that is applied directly
            to samples.
            In the dict form keys correspond to keys in the sample and
            values are callables with signature
            ``t(value: any, params: dict) -> any``
            (e.g., an instance of :py:class:`.Compose`)
            with ``params`` the value returned by the ``rng`` callable.
        rng: callable with signature ``rng(sample: dict) -> dict`` that
            returns a dict of parameters applied to transforms
        batch_size: same batch size as given to the ``DataLoader``
        epoch: starting epoch, zero indexed; only relevant when resuming
        copy: see :py:meth:`datadings.reader.reader.Reader.iter`
        chunk_size: see :py:meth:`datadings.reader.reader.Reader.iter`
        group: distributed process group to use
            (if not using the default)
    """
    def __init__(
            self,
            reader: Reader,
            transforms=None,
            rng=None,
            batch_size=None,
            epoch=0,
            copy=True,
            chunk_size=16,
            group=None
    ):
        DatasetBase.__init__(self, reader, transforms, rng)
        if is_initialized():
            self.rank = get_rank(group)
            self.world_size = get_world_size(group)
        else:
            self.rank = 0
            self.world_size = 1
        self.batch_size = batch_size
        self.epoch = epoch
        self.copy = copy
        self.chunk_size = chunk_size

    def _start_stop(self):
        # check which worker this process is
        info = get_worker_info()
        if info is None:
            num_workers, worker_index = 1, 0
        else:
            num_workers, worker_index = info.num_workers, info.id
        # calculate the number and size of shards
        n = len(self.reader)
        num_shards = self.world_size * num_workers
        shard_iters = n / num_shards
        # check if given batch_size could lead to empty last batch
        if self.batch_size is not None:
            overhang = shard_iters % self.batch_size
            if 0 < overhang < 1:
                raise RuntimeError(
                    f"len(shard) % batch_size = {overhang}, "
                    "so last batch may be empty; "
                    "change num_workers, batch_size, or world_size"
                )
        # rank is offset by current epoch
        # this makes ranks cycle through shards during training
        index = (self.rank + self.epoch) * num_workers + worker_index
        index %= num_shards
        start = max(0, int(round(shard_iters * index)))
        # last shard should include the last sample
        if index + 1 == num_shards:
            stop = n
        else:
            stop = min(n, int(round(shard_iters * (index + 1))))
        return start, stop

    def __len__(self):
        start, stop = self._start_stop()
        return stop - start

    def __iter__(self):
        # create the iterator for the current shard
        start, stop = self._start_stop()
        it = self.reader.iter(
            start, stop, copy=self.copy, chunk_size=self.chunk_size
        )
        # advance epoch by one
        self.epoch += 1
        # yield transformed samples
        with self.reader:
            yield from map(self.transform, it)
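

# Illustrative sketch (not part of datadings): how the shard boundaries are
# derived for a given rank and worker, mirroring the arithmetic in
# _start_stop above. All numbers below are made up for this example.
def _shard_boundaries_example(n=1000, world_size=2, num_workers=4,
                              rank=1, worker_index=3, epoch=0):
    num_shards = world_size * num_workers    # 8 shards
    shard_iters = n / num_shards             # 125.0 samples per shard
    # rank offset by epoch makes ranks cycle through shards
    index = (rank + epoch) * num_workers + worker_index
    index %= num_shards                      # shard 7 for this worker
    start = max(0, int(round(shard_iters * index)))
    if index + 1 == num_shards:
        stop = n
    else:
        stop = min(n, int(round(shard_iters * (index + 1))))
    return start, stop                       # (875, 1000)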


class CompressedToPIL:
    """
    torchvision-compatible transform that takes a compressed
    image as bytes (or similar) and returns a PIL image.
    """
    def __call__(self, buf):
        try:
            img = Image.fromarray(decode_jpeg(buf), 'RGB')
        except ValueError:
            img = Image.open(BytesIO(buf)).convert('RGB')
        return img

    def __repr__(self):
        return self.__class__.__name__ + '()'
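

# Illustrative sketch (not part of datadings): CompressedToPIL falls back to
# PIL for non-JPEG data, so e.g. PNG bytes work too. The image size and
# format used here are made up for this example.
def _compressed_to_pil_example():
    buf = BytesIO()
    Image.new('RGB', (4, 4)).save(buf, format='PNG')
    # decode_jpeg rejects the PNG header, so the PIL fallback is used
    img = CompressedToPIL()(buf.getvalue())
    return img.size  # (4, 4)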


def dict2tuple(it, keys=('image', 'label')):
    """
    Utility function that extracts and yields the given keys
    from each sample in the given iterator.
    """
    for sample in it:
        yield tuple(sample[k] for k in keys)
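

# Illustrative sketch (not part of datadings): dict2tuple with custom keys.
# The sample dict and key names are made up for this example.
def _dict2tuple_example():
    samples = [{'image': 'img0', 'label': 0, 'key': 'a'}]
    return list(dict2tuple(iter(samples), keys=('key', 'label')))  # [('a', 0)]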