datadings.torch package

class datadings.torch.Compose(*transforms, prefix='')[source]

Bases: object

Compose a sequence of transform functions. Functions must accept the intended value from samples as their first argument and may have an arbitrary number of additional positional and keyword arguments.

Example usage with Dataset:

import random
from datadings.torch import Compose
from datadings.torch import Dataset
from datadings.reader import ListReader

def add(v, number):
    return v + number

def sub(x, value):
    return x - value

# draws fresh parameters for every sample; keys match the
# parameter names of the transform functions above
def rng(_):
    return {
        'number': random.randrange(1, 10),
        'value': random.randrange(1, 10),
    }

samples = [{'a': 0, 'b': 0, 'c': 0} for _ in range(10)]
reader = ListReader(samples)
transforms = {
    'a': Compose(add),         # 'a' gets number added
    'b': Compose(sub),         # 'b' gets value subtracted
    'c': Compose((add, sub)),  # transforms may also be one iterable
}
dataset = Dataset(reader, transforms=transforms, rng=rng)
for i in range(len(dataset)):
    print(dataset[i])
Parameters:
  • transforms – sequence of transform functions, either one iterable or varargs

  • prefix – string prefix for parameter names, e.g., a function that normally requires a parameter size will instead receive the value of 'mask_size' if prefix='mask_' is given (see the sketch after this list)
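
For example, prefix lets two sample keys share the same transform function while drawing distinct parameters from the rng dict. A minimal sketch in the style of the example above (scale and its parameter names are hypothetical):

import random
from datadings.reader import ListReader
from datadings.torch import Compose
from datadings.torch import Dataset

def scale(v, factor):
    return v * factor

# the 'mask' Compose looks up 'mask_factor' instead of 'factor'
transforms = {
    'image': Compose(scale),
    'mask': Compose(scale, prefix='mask_'),
}

def rng(_):
    return {
        'factor': random.randrange(1, 10),
        'mask_factor': random.randrange(1, 10),
    }

reader = ListReader([{'image': 1, 'mask': 1} for _ in range(3)])
dataset = Dataset(reader, transforms=transforms, rng=rng)
for i in range(len(dataset)):
    print(dataset[i])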

class datadings.torch.CompressedToPIL[source]

Bases: object

Compatible torchvision transform that takes a compressed image as bytes (or similar) and returns a PIL image.
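
A minimal standalone sketch of its behavior (file path hypothetical; inside a Compose it is used as in the Dataset example below):

from datadings.torch import CompressedToPIL

with open('example.jpg', 'rb') as f:  # any JPEG/PNG file
    buf = f.read()

img = CompressedToPIL()(buf)  # decodes the bytes into a PIL image
print(img.size, img.mode)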

class datadings.torch.Dataset(*args: Any, **kwargs: Any)[source]

Bases: DatasetBase, Dataset

Implementation of torch.utils.data.Dataset.

Warning

Dataset can be significantly slower than IterableDataset. If shuffling is necessary, consider using QuasiShuffler instead.
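
A rough sketch of that alternative, wrapping the reader in a QuasiShuffler before handing it to an IterableDataset (import path and constructor defaults assumed):

from datadings.reader import MsgpackReader, QuasiShuffler
from datadings.torch import IterableDataset

reader = MsgpackReader('train.msgpack')  # hypothetical path
shuffled = QuasiShuffler(reader)  # approximate shuffling with mostly sequential reads
ds = IterableDataset(shuffled, batch_size=256)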

Example usage with the PyTorch DataLoader:

from datadings.reader import MsgpackReader
from datadings.torch import Compose, CompressedToPIL, Dataset, dict2tuple
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from tqdm import tqdm

path = '.../train.msgpack'
batch_size = 256
reader = MsgpackReader(path)
transforms = {'image': Compose(
    CompressedToPIL(),
    ...,  # additional transforms elided
    ToTensor(),
)}
ds = Dataset(reader, transforms=transforms)
train = DataLoader(dataset=ds, batch_size=batch_size)
for epoch in range(3):
    for x, y in dict2tuple(tqdm(train)):
        pass
Parameters:
  • reader – the datadings reader instance

  • transforms – Transforms applied to samples before they are returned. Either a dict of transform functions or a callable with signature f(sample: dict) -> dict that is applied directly to samples (see the sketch after this list). In the dict form, keys correspond to keys in the sample and values are callables with signature t(value: any, params: dict) -> any (e.g., an instance of Compose), with params the value returned by the rng callable.

  • rng – callable with signature rng(params: dict) -> dict that returns a dict of parameters applied to transforms
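
A minimal sketch of the callable form, which receives and returns whole samples (key names hypothetical):

from datadings.reader import ListReader
from datadings.torch import Dataset

def normalize(sample):
    # f(sample: dict) -> dict, applied to the whole sample
    sample['image'] = sample['image'] / 255
    return sample

reader = ListReader([{'image': 128, 'label': 0}])
ds = Dataset(reader, transforms=normalize)
print(ds[0])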

class datadings.torch.DatasetBase(reader: Reader, transforms=None, rng=None)[source]

Bases: object

class datadings.torch.IterableDataset(*args: Any, **kwargs: Any)[source]

Bases: DatasetBase, IterableDataset

Implementation of torch.utils.data.IterableDataset to use with datadings readers.

With distributed training the reader is divided into world_size * num_workers shards. Each dataloader worker of each rank iterates over a different shard. The final batch delivered by a worker may be smaller than the batch size if the length of the reader is not divisible by batch_size * num_shards.

Note

Set persistent_workers=True for the DataLoader to let the dataset object track the current epoch. It then cycles through shards, so each rank iterates over a different part of the dataset every epoch. Without this option torch may create new worker processes at any time, which resets the dataset to its initial state.

Warning

Raises RuntimeError if 0 < len(shard) % batch_size < 1, since this may lead to an uneven number of batches generated by each worker. This can lead to crashes if it happens between the workers of a rank, or to deadlock if ranks receive a different number of batches. Change num_workers, batch_size, or world_size to avoid this.
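
This condition can be checked ahead of time with plain arithmetic. A sketch mirroring the documented formula (the even float division across shards is an assumption about the internals):

def check_sharding(num_samples, world_size, num_workers, batch_size):
    num_shards = world_size * num_workers
    shard_len = num_samples / num_shards  # may be fractional
    r = shard_len % batch_size
    if 0 < r < 1:
        raise RuntimeError(f'{shard_len} % {batch_size} == {r}')

check_sharding(1000, world_size=2, num_workers=4, batch_size=25)   # ok
check_sharding(1002, world_size=2, num_workers=4, batch_size=125)  # raises: 125.25 % 125 == 0.25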

Example usage with the PyTorch DataLoader:

from datadings.reader import MsgpackReader
from datadings.torch import Compose, CompressedToPIL, IterableDataset, dict2tuple
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from tqdm import tqdm

path = '.../train.msgpack'
batch_size = 256
reader = MsgpackReader(path)
transforms = {'image': Compose(
    CompressedToPIL(),
    ...,  # additional transforms elided
    ToTensor(),
)}
ds = IterableDataset(
    reader,
    transforms=transforms,
    batch_size=batch_size,
)
train = DataLoader(
    dataset=ds,
    batch_size=batch_size,
    num_workers=4,
    persistent_workers=True,
)
for epoch in range(3):
    print('Epoch', epoch)
    for x, y in dict2tuple(tqdm(train)):
        pass
Parameters:
  • reader – the datadings reader instance

  • transforms – Transforms applied to samples before they are returned. Either a dict of transform functions or a callable with signature f(sample: dict) -> dict that is applied directly to samples. In the dict form, keys correspond to keys in the sample and values are callables with signature t(value: any, params: dict) -> any (e.g., an instance of Compose), with params the value returned by the rng callable.

  • rng – callable with signature rng(params: dict) -> dict that returns a dict of parameters applied to transforms

  • batch_size – same batch size as given to the DataLoader

  • epoch – starting epoch, zero-indexed; only relevant when resuming (see the sketch after this list)

  • copy – see datadings.reader.reader.Reader.iter()

  • chunk_size – see datadings.reader.reader.Reader.iter()

  • group – distributed process group to use (if not using the default)
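
When resuming training, the documented epoch parameter lets shard cycling continue where it left off. A sketch assuming a checkpoint dict (checkpoint contents hypothetical):

from datadings.reader import MsgpackReader
from datadings.torch import IterableDataset

checkpoint = {'epoch': 5}  # hypothetical: restored from disk
reader = MsgpackReader('.../train.msgpack')
ds = IterableDataset(
    reader,
    batch_size=256,
    epoch=checkpoint['epoch'],  # zero-indexed starting epoch
)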

datadings.torch.dict2tuple(it, keys=('image', 'label'))[source]

Utility function that extracts and yields the given keys from each sample in the given iterator.
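
A minimal sketch of its behavior with the default keys:

from datadings.torch import dict2tuple

batches = [
    {'image': 'img0', 'label': 0},
    {'image': 'img1', 'label': 1},
]
for image, label in dict2tuple(iter(batches)):
    print(image, label)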

datadings.torch.no_rng(_)[source]

Stand-in rng callable that ignores its argument and returns an empty parameter dict.