PyTorch integration

Warning

This functionality is highly experimental and subject to change in future versions!

datadings provides experimental integration with PyTorch. There are two options:

  1. Dataset

  2. IterableDataset

These implement the respective PyTorch dataset classes and work as expected with the PyTorch DataLoader.

Note

persistent_workers=True must be used to let IterableDataset track the current epoch.

Warning

Dataset can be significantly slower than IterableDataset. If shuffling is necessary, consider using QuasiShuffler instead, as sketched below.
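
For illustration, a minimal sketch of this, assuming QuasiShuffler can simply wrap an existing reader with its default buffer settings (check the datadings.reader documentation for the exact parameters):

from datadings.reader import MsgpackReader
from datadings.reader import QuasiShuffler

# wrap the sequential reader to get approximate shuffling without the
# random-access overhead of Dataset; default buffer settings are assumed
reader = QuasiShuffler(MsgpackReader('.../train.msgpack'))

The wrapped reader is then passed to IterableDataset exactly as in the example below.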

Example usage with the PyTorch DataLoader:

from datadings.reader import MsgpackReader
from datadings.torch import IterableDataset
from datadings.torch import CompressedToPIL
from datadings.torch import dict2tuple
from datadings.torch import Compose

from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.transforms import RandomResizedCrop
from torchvision.transforms import RandomHorizontalFlip


def main():
    path = '.../train.msgpack'
    batch_size = 256
    transforms = {'image': Compose(
        CompressedToPIL(),
        RandomResizedCrop((224, 224)),
        RandomHorizontalFlip(),
        ToTensor(),
    )}
    reader = MsgpackReader(path)
    ds = IterableDataset(
        reader,
        transforms=transforms,
        batch_size=batch_size,
    )
    train = DataLoader(
        dataset=ds,
        batch_size=batch_size,
        num_workers=4,
        persistent_workers=True,
    )
    for epoch in range(3):
        print('Epoch', epoch)
        for x, y in dict2tuple(tqdm(train)):
            pass


if __name__ == "__main__":
    main()

In our example transforms is a dictionary with a single key 'image', so the composed transformation is applied to the value stored under that key. You can add more keys to apply different transforms to other values of the sample, as sketched below.
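
For illustration, a second entry could transform the label alongside the image; one_hot here is a hypothetical helper, not part of datadings:

import torch

from datadings.torch import Compose
from datadings.torch import CompressedToPIL
from torchvision.transforms import ToTensor


def one_hot(label):
    # hypothetical helper: encode an integer class label as a one-hot vector
    return torch.nn.functional.one_hot(torch.tensor(label), 1000)


transforms = {
    'image': Compose(CompressedToPIL(), ToTensor()),
    'label': Compose(one_hot),
}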

Note

When using non-functional torchvision transforms you will see warnings that they only accept varargs, because their call signatures are opaque. This is fine, since these transforms only need the value as input. Other transforms may not work, though.

If you need to share randomness between transformations (e.g., to synchronize augmentation steps between image and mask in semantic segmentation), you can use functions that accept the random values as parameters, like the functional transforms from torchvision. Datasets accept a callable rng parameter with signature rng(sample: dict) -> dict, where sample is the sample that is about to be transformed and the returned dictionary must contain all positional parameters required by the transform functions. The Compose object reads the signatures of your transform functions to determine which parameters are required. A minimal example of how this works:

import random
from datadings.torch import Compose
from datadings.torch import Dataset
from datadings.reader import ListReader

def add(v, number):
    return v + number

def sub(x, value):
    return x - value

def rng(_):
    return {
        'number': random.randrange(1, 10),
        'value': random.randrange(1, 10),
    }

samples = [{'a': 0, 'b': 0, 'c': 0} for _ in range(10)]
reader = ListReader(samples)
transforms = {
    'a': Compose(add),
    'b': Compose(sub),
    'c': Compose((add, sub)),
}
dataset = Dataset(reader, transforms=transforms, rng=rng)
for i in range(len(dataset)):
    print(dataset[i])

Note

You can use functools.partial() (or similar) to set constant values for parameters and change defaults for keyword arguments instead of including them in your rng dictionary.
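
For example, reusing add and sub from above, value can be fixed with functools.partial() so that rng only needs to provide number:

from functools import partial

transforms = {
    'a': Compose(add),
    # value is fixed to 5 here, so rng no longer needs to return it
    'c': Compose((add, partial(sub, value=5))),
}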

Warning

Transform functions will receive the same value if they share parameter names. If this is not intended, you must wrap one of those functions in another function and change one of the parameter names, as sketched below.
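
A minimal sketch of such a wrapper, renaming sub's value parameter to offset so the two transforms can receive independent random values:

def sub_offset(x, offset):
    # same operation as sub, but with a distinct parameter name so it
    # receives its own entry from the rng dictionary
    return sub(x, offset)

transforms = {'c': Compose((add, sub_offset))}

The rng callable then has to return an 'offset' entry in addition to 'number' and 'value'.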

Alternatively, transforms may be a custom function with signature t(sample: dict) -> dict. This allows you to use multiple values from the sample in one transform, create new values based on the sample, and so on.
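
A minimal sketch, reusing the ListReader samples from the example above (the new 'sum' key is purely illustrative):

def transform(sample):
    # use several values from the sample at once and derive a new one
    sample['sum'] = sample['a'] + sample['b']
    return sample

dataset = Dataset(reader, transforms=transform)
print(dataset[0])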